insurance_gam.anam
insurance_gam.anam — Actuarial Neural Additive Model subpackage.
Re-exports the full public API of the original insurance-anam package.
Requires the neural extra:
pip install insurance-gam[neural]
"""
insurance_gam.anam — Actuarial Neural Additive Model subpackage.

Re-exports the full public API of the original insurance-anam package.

Requires the ``neural`` extra::

    pip install insurance-gam[neural]
"""

try:
    from .api import ANAM
    from .feature_network import CategoricalFeatureNetwork, FeatureNetwork
    from .interaction_network import InteractionNetwork
    from .losses import (
        bernoulli_deviance,
        gamma_deviance,
        l1_sparsity_penalty,
        l2_ridge_penalty,
        poisson_deviance,
        smoothness_penalty,
        tweedie_deviance,
    )
    from .model import ANAMModel, FeatureConfig, InteractionConfig
    from .shapes import ShapeFunction, extract_shape_functions, plot_all_shapes
    from .trainer import ANAMTrainer, TrainingConfig, TrainingHistory
    from .utils import (
        StandardScaler,
        compare_shapes_to_glm,
        compute_deviance_stat,
        select_interactions_correlation,
        select_interactions_residual,
        shapes_to_relativity_table,
    )
except ImportError as _e:
    raise ImportError(
        f"insurance_gam.anam requires PyTorch. "
        f"Install it with: pip install insurance-gam[neural]\n"
        f"Original error: {_e}"
    ) from _e

__all__ = [
    "ANAM",
    "ANAMModel",
    "FeatureConfig",
    "InteractionConfig",
    "FeatureNetwork",
    "CategoricalFeatureNetwork",
    "InteractionNetwork",
    "ANAMTrainer",
    "TrainingConfig",
    "TrainingHistory",
    "poisson_deviance",
    "gamma_deviance",
    "tweedie_deviance",
    "bernoulli_deviance",
    "smoothness_penalty",
    "l1_sparsity_penalty",
    "l2_ridge_penalty",
    "ShapeFunction",
    "extract_shape_functions",
    "plot_all_shapes",
    "StandardScaler",
    "select_interactions_correlation",
    "select_interactions_residual",
    "shapes_to_relativity_table",
    "compare_shapes_to_glm",
    "compute_deviance_stat",
]
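A minimal quick-start sketch, not taken from the package's own examples: it fits the high-level ANAM estimator on a small synthetic Poisson frequency dataset with exposure weights, using only the documented API below. The column names, data-generating process, and hyperparameters are illustrative assumptions.

import numpy as np
import polars as pl

from insurance_gam.anam import ANAM

rng = np.random.default_rng(0)
n = 5_000
X = pl.DataFrame({
    "driver_age": rng.uniform(18, 80, n),
    "vehicle_power": rng.integers(0, 5, n),   # integer-coded categorical levels 0..4
})
exposure = rng.uniform(0.1, 1.0, n)           # policy-years on risk

# Illustrative frequency that falls with driver_age, so the monotone
# constraint below is consistent with the data.
lam = exposure * np.exp(-2.0 - 0.01 * (X["driver_age"].to_numpy() - 40.0))
y = rng.poisson(lam)

model = ANAM(
    categorical_features=["vehicle_power"],
    monotone_decreasing=["driver_age"],
    loss="poisson",
    link="log",
    n_epochs=50,
    verbose=1,
)
model.fit(X, y, sample_weight=exposure)

# Expected claim counts for one policy-year of exposure per row.
mu = model.predict(X)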
class ANAM:
    """sklearn-compatible Actuarial Neural Additive Model.

    Parameters
    ----------
    feature_configs:
        List of FeatureConfig objects defining features. If None, all
        features are treated as continuous with no constraints.
    feature_names:
        Feature names (used when feature_configs is None). Must match
        columns of X passed to fit().
    categorical_features:
        List of feature names that are categorical. Only used when
        feature_configs is None.
    monotone_increasing:
        Feature names to constrain as monotone increasing.
    monotone_decreasing:
        Feature names to constrain as monotone decreasing.
    link:
        Link function: 'log' (Poisson/Tweedie), 'identity' (Gaussian),
        'logit' (binary).
    loss:
        Distributional loss: 'poisson', 'tweedie', 'gamma', 'mse'.
    tweedie_p:
        Tweedie power parameter (only for loss='tweedie').
    interaction_pairs:
        List of (feature_i, feature_j) tuples for interaction subnetworks.
    hidden_sizes:
        Default hidden layer sizes for subnetworks.
    n_epochs:
        Maximum training epochs.
    batch_size:
        Mini-batch size.
    learning_rate:
        Adam learning rate.
    lambda_smooth:
        Smoothness regularisation weight.
    lambda_l2:
        L2 ridge weight.
    lambda_l1:
        L1 sparsity weight.
    patience:
        Early stopping patience (epochs).
    normalize:
        If True, standardise continuous features before training. The
        scaler is stored and applied automatically during predict().
    verbose:
        Training verbosity. 0 = silent.
    device:
        'cpu', 'cuda', or None (auto).
    """

    def __init__(
        self,
        feature_configs: Optional[List[FeatureConfig]] = None,
        feature_names: Optional[List[str]] = None,
        categorical_features: Optional[List[str]] = None,
        monotone_increasing: Optional[List[str]] = None,
        monotone_decreasing: Optional[List[str]] = None,
        link: Literal["log", "identity", "logit"] = "log",
        loss: Literal["poisson", "tweedie", "gamma", "mse"] = "poisson",
        tweedie_p: float = 1.5,
        interaction_pairs: Optional[List[Tuple[str, str]]] = None,
        hidden_sizes: Optional[List[int]] = None,
        n_epochs: int = 100,
        batch_size: int = 512,
        learning_rate: float = 1e-3,
        lambda_smooth: float = 1e-4,
        lambda_l2: float = 1e-4,
        lambda_l1: float = 0.0,
        patience: int = 15,
        normalize: bool = True,
        verbose: int = 0,
        device: Optional[str] = None,
    ) -> None:
        self.feature_configs = feature_configs
        self.feature_names = feature_names
        # Store exactly as passed — no coercion to [] here.
        # sklearn clone() checks that get_params() round-trips through __init__
        # exactly. Coercing None->[] would break that invariant.
        # Defaults are applied at point-of-use in fit() and _build_feature_configs().
        self.categorical_features = categorical_features
        self.monotone_increasing = monotone_increasing
        self.monotone_decreasing = monotone_decreasing
        self.link = link
        self.loss = loss
        self.tweedie_p = tweedie_p
        self.interaction_pairs = interaction_pairs
        self.hidden_sizes = hidden_sizes
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.lambda_smooth = lambda_smooth
        self.lambda_l2 = lambda_l2
        self.lambda_l1 = lambda_l1
        self.patience = patience
        self.normalize = normalize
        self.verbose = verbose
        self.device = device

        # Set after fit()
        self.model_: Optional[ANAMModel] = None
        self.scaler_: Optional[StandardScaler] = None
        self.history_: Optional[TrainingHistory] = None
        self.feature_names_in_: Optional[List[str]] = None
        self._continuous_col_indices: List[int] = []
        # P0-1 fix: track both the cached shapes and the n_points used to
        # build them so we can detect stale cache on different n_points.
        self._shapes_cache: Optional[Dict[str, ShapeFunction]] = None
        self._shapes_cache_n_points: Optional[int] = None
        self._X_train_scaled: Optional[np.ndarray] = None
        # P0-2 fix: per-feature category remapping tables set during fit().
        # Maps feature name -> array where remap[original_code] = new_code.
        self._cat_remap_: Dict[str, np.ndarray] = {}

    def fit(
        self,
        X: Union[np.ndarray, pl.DataFrame],
        y: Union[np.ndarray, pl.Series],
        sample_weight: Optional[Union[np.ndarray, pl.Series]] = None,
    ) -> "ANAM":
        """Fit the ANAM model.

        Parameters
        ----------
        X:
            Feature matrix, shape (n, p). Polars DataFrame accepted.
        y:
            Target vector (claim counts, rates, or severities).
        sample_weight:
            Exposure weights. For frequency models this is policy duration
            (e.g. years on risk). If None, uniform exposure assumed.

        Returns
        -------
        self
        """
        X_arr, y_arr, w_arr, names = self._validate_input(X, y, sample_weight)
        self.feature_names_in_ = names

        # Build feature configs if not provided. Also populates self._cat_remap_
        # for any categorical columns that are not zero-indexed consecutive ints.
        feat_configs = self.feature_configs or self._build_feature_configs(names, X_arr)

        # P0-2 fix: apply category remapping to training data so that the
        # trainer and scaler see zero-indexed codes, matching the embedding
        # table sizes that _build_feature_configs computed.
        if self._cat_remap_:
            X_arr = self._apply_cat_remap(X_arr)

        # Identify continuous column indices for normalisation
        self._continuous_col_indices = [
            i for i, cfg in enumerate(feat_configs)
            if cfg.feature_type == "continuous"
        ]

        # Normalise continuous features
        if self.normalize and self._continuous_col_indices:
            self.scaler_ = StandardScaler()
            cont_cols = np.array(self._continuous_col_indices)
            X_arr_norm = X_arr.copy()
            X_arr_norm[:, cont_cols] = self.scaler_.fit(
                X_arr[:, cont_cols], feature_names=[names[i] for i in cont_cols]
            ).transform(X_arr[:, cont_cols])
        else:
            X_arr_norm = X_arr

        self._X_train_scaled = X_arr_norm

        # Build interaction configs
        interaction_configs: List[InteractionConfig] = []
        for fi, fj in (self.interaction_pairs or []):
            interaction_configs.append(
                InteractionConfig(feature_i=fi, feature_j=fj)
            )

        # Build model
        self.model_ = ANAMModel(
            feature_configs=feat_configs,
            link=self.link,
            interaction_configs=interaction_configs,
            hidden_sizes=self.hidden_sizes or [64, 32],
        )

        # Build trainer
        train_cfg = TrainingConfig(
            loss=self.loss,
            tweedie_p=self.tweedie_p,
            n_epochs=self.n_epochs,
            batch_size=self.batch_size,
            learning_rate=self.learning_rate,
            lambda_smooth=self.lambda_smooth,
            lambda_l2=self.lambda_l2,
            lambda_l1=self.lambda_l1,
            patience=self.patience,
            verbose=self.verbose,
            device=self.device,
        )
        trainer = ANAMTrainer(self.model_, train_cfg)
        self.history_ = trainer.fit(X_arr_norm, y_arr, exposure=w_arr)
        self.model_ = trainer.model  # may have moved to device

        return self

    def predict(
        self,
        X: Union[np.ndarray, pl.DataFrame],
        exposure: Optional[Union[np.ndarray, pl.Series]] = None,
    ) -> np.ndarray:
        """Predict expected mean (mu) for each observation.

        Parameters
        ----------
        X:
            Feature matrix.
        exposure:
            Exposure for each row. If None, exposure=1 assumed (predicts
            per-policy-year rate for log-link models).

        Returns
        -------
        np.ndarray
            Predicted means, shape (n,).
        """
        self._check_fitted()

        X_arr = self._to_array(X)
        X_arr = self._apply_cat_remap(X_arr)
        X_arr = self._apply_scaling(X_arr)

        if exposure is not None:
            if isinstance(exposure, pl.Series):
                exposure = exposure.to_numpy()
            exp_arr = np.asarray(exposure, dtype=np.float32)
            log_exp = torch.tensor(np.log(exp_arr.clip(min=1e-8)), dtype=torch.float32)
        else:
            log_exp = None

        X_t = torch.tensor(X_arr, dtype=torch.float32)

        self.model_.eval()
        with torch.no_grad():
            mu = self.model_(X_t, log_exposure=log_exp)

        return mu.cpu().numpy()

    def score(
        self,
        X: Union[np.ndarray, pl.DataFrame],
        y: Union[np.ndarray, pl.Series],
        sample_weight: Optional[Union[np.ndarray, pl.Series]] = None,
    ) -> float:
        """Return negative mean deviance (higher = better).

        Follows sklearn convention where score() returns a value where
        higher is better. We negate deviance so that model.score() can
        be maximised by hyperparameter search.
        """
        self._check_fitted()

        y_arr = np.asarray(y if not isinstance(y, pl.Series) else y.to_numpy(), dtype=np.float32)
        y_pred = self.predict(X, exposure=sample_weight)

        w = None
        if sample_weight is not None:
            w = np.asarray(
                sample_weight if not isinstance(sample_weight, pl.Series)
                else sample_weight.to_numpy(),
                dtype=np.float32,
            )

        dev = compute_deviance_stat(
            y_arr, y_pred, exposure=w,
            loss=self.loss, tweedie_p=self.tweedie_p
        )
        return -dev

    def shape_functions(
        self,
        n_points: int = 200,
        category_labels: Optional[Dict[str, Dict[int, str]]] = None,
    ) -> Dict[str, ShapeFunction]:
        """Extract and cache shape functions for all features.

        Returns a dict mapping feature name to ShapeFunction. Shape
        functions are evaluated over the observed training data range.

        The cache is keyed on n_points. Calling with a different n_points
        than the cached value discards the stale cache and recomputes.
        """
        self._check_fitted()
        assert self._X_train_scaled is not None

        # P0-1 fix: invalidate cache when n_points changes.
        if self._shapes_cache is None or self._shapes_cache_n_points != n_points:
            self._shapes_cache = extract_shape_functions(
                self.model_,
                self._X_train_scaled,
                n_points=n_points,
                category_labels=category_labels,
            )
            self._shapes_cache_n_points = n_points
        return self._shapes_cache

    def feature_importance(self) -> pl.DataFrame:
        """Feature importance as subnetwork weight norms.

        Returns a Polars DataFrame with columns [feature, importance],
        sorted by importance descending.
        """
        self._check_fitted()
        imp = self.model_.feature_importance()
        return pl.DataFrame(
            {"feature": list(imp.keys()), "importance": list(imp.values())}
        ).sort("importance", descending=True)

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """sklearn get_params for grid search compatibility.

        Returns ALL constructor parameters — including structural ones
        (feature_configs, interaction_pairs, etc.) — so that
        sklearn.base.clone() can reconstruct a functionally equivalent
        model.
        """
        # P0-4 fix: include all constructor parameters, not just numerical ones.
        return {
            "feature_configs": self.feature_configs,
            "feature_names": self.feature_names,
            "categorical_features": self.categorical_features,
            "monotone_increasing": self.monotone_increasing,
            "monotone_decreasing": self.monotone_decreasing,
            "link": self.link,
            "loss": self.loss,
            "tweedie_p": self.tweedie_p,
            "interaction_pairs": self.interaction_pairs,
            "hidden_sizes": self.hidden_sizes,
            "n_epochs": self.n_epochs,
            "batch_size": self.batch_size,
            "learning_rate": self.learning_rate,
            "lambda_smooth": self.lambda_smooth,
            "lambda_l2": self.lambda_l2,
            "lambda_l1": self.lambda_l1,
            "patience": self.patience,
            "normalize": self.normalize,
            "verbose": self.verbose,
            "device": self.device,
        }

    def set_params(self, **params: Any) -> "ANAM":
        """sklearn set_params for grid search compatibility."""
        for key, val in params.items():
            setattr(self, key, val)
        return self

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _check_fitted(self) -> None:
        if self.model_ is None:
            raise RuntimeError("Call fit() before predict() or score().")

    def _to_array(self, X: Union[np.ndarray, pl.DataFrame]) -> np.ndarray:
        if isinstance(X, pl.DataFrame):
            return X.to_numpy()
        return np.asarray(X, dtype=np.float64)

    def _apply_scaling(self, X_arr: np.ndarray) -> np.ndarray:
        if self.scaler_ is None or not self._continuous_col_indices:
            return X_arr.astype(np.float32)
        X_norm = X_arr.copy().astype(np.float32)
        cont_cols = np.array(self._continuous_col_indices)
        X_norm[:, cont_cols] = self.scaler_.transform(X_arr[:, cont_cols].astype(np.float64)).astype(np.float32)
        return X_norm

    def _apply_cat_remap(self, X_arr: np.ndarray) -> np.ndarray:
        """Apply stored category remappings.

        P0-2 fix: if any categorical features were remapped during fit
        (because they were not zero-indexed consecutive integers), apply
        the same remapping here. Called both during fit() (on training
        data) and during predict() (on new data) so that codes always
        align with the trained embedding table.
        """
        if not self._cat_remap_:
            return X_arr
        if self.feature_names_in_ is None:
            return X_arr
        X_out = X_arr.copy()
        for feat_name, remap in self._cat_remap_.items():
            if feat_name in self.feature_names_in_:
                col_idx = self.feature_names_in_.index(feat_name)
                codes = X_out[:, col_idx].astype(int)
                X_out[:, col_idx] = remap[codes].astype(X_out.dtype)
        return X_out

    def _validate_input(
        self,
        X: Union[np.ndarray, pl.DataFrame],
        y: Union[np.ndarray, pl.Series],
        w: Optional[Union[np.ndarray, pl.Series]],
    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], List[str]]:
        """Convert inputs to numpy and extract feature names."""
        if isinstance(X, pl.DataFrame):
            names = X.columns
            X_arr = X.to_numpy().astype(np.float64)
        else:
            X_arr = np.asarray(X, dtype=np.float64)
            names = self.feature_names or [f"f{i}" for i in range(X_arr.shape[1])]

        if isinstance(y, pl.Series):
            y_arr = y.to_numpy().astype(np.float32)
        else:
            y_arr = np.asarray(y, dtype=np.float32)

        if w is not None:
            if isinstance(w, pl.Series):
                w_arr: Optional[np.ndarray] = w.to_numpy().astype(np.float32)
            else:
                w_arr = np.asarray(w, dtype=np.float32)
        else:
            w_arr = None

        return X_arr, y_arr, w_arr, list(names)

    def _build_feature_configs(
        self, names: List[str], X_arr: np.ndarray
    ) -> List[FeatureConfig]:
        """Auto-construct FeatureConfig list from names and constraints.

        P0-2 fix: for categorical features, validate that column values are
        zero-indexed consecutive integers. If not, auto-remap and warn. The
        remap is stored in self._cat_remap_ and applied to the training
        array in fit() before the scaler and trainer receive it.
        """
        configs: List[FeatureConfig] = []
        self._cat_remap_ = {}

        for i, name in enumerate(names):
            if name in (self.categorical_features or []):
                col = X_arr[:, i].astype(int)
                unique_vals = np.unique(col)
                expected = np.arange(len(unique_vals))
                if not np.array_equal(unique_vals, expected):
                    # Not zero-indexed consecutive integers. Auto-remap.
                    warnings.warn(
                        f"Categorical feature '{name}' has values {unique_vals.tolist()} "
                        f"which are not zero-indexed consecutive integers. "
                        f"Auto-remapping to 0..{len(unique_vals) - 1}. "
                        f"The same remapping will be applied at predict time. "
                        f"If you rely on specific category indices in post-hoc analysis, "
                        f"remap the column yourself before passing to fit().",
                        UserWarning,
                        stacklevel=3,
                    )
                    # Build a lookup array: remap[original_code] = new_code.
                    max_orig = int(unique_vals.max())
                    remap = np.full(max_orig + 1, -1, dtype=int)
                    for new_code, orig_code in enumerate(unique_vals):
                        remap[orig_code] = new_code
                    self._cat_remap_[name] = remap
                    col = remap[col]

                n_cats = int(col.max()) + 1
                configs.append(
                    FeatureConfig(
                        name=name,
                        feature_type="categorical",
                        n_categories=n_cats,
                    )
                )
            else:
                mono: Literal["increasing", "decreasing", "none"] = "none"
                if name in (self.monotone_increasing or []):
                    mono = "increasing"
                elif name in (self.monotone_decreasing or []):
                    mono = "decreasing"
                configs.append(
                    FeatureConfig(
                        name=name,
                        feature_type="continuous",
                        monotonicity=mono,
                    )
                )
        return configs
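Once fitted, the estimator's interpretation helpers can be called directly. The sketch below continues from the quick-start above and only uses shape_functions() and feature_importance() as documented; the printing is illustrative.

# Per-feature shape functions, evaluated over the observed training range.
shapes = model.shape_functions(n_points=100)
for name, shape in shapes.items():
    print(name, type(shape).__name__)

# Subnetwork weight-norm importances as a Polars DataFrame.
imp = model.feature_importance()
print(imp)

# Calling again with a different grid size recomputes rather than returning
# the stale cache (the P0-1 behaviour noted in the source).
shapes_fine = model.shape_functions(n_points=400)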
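Because get_params()/set_params() round-trip every constructor argument and score() returns negative mean deviance, the estimator should be usable with scikit-learn model-selection utilities. The following is a hedged sketch rather than a tested recipe (it reuses the synthetic X and y from the quick-start; the parameter grid and cv choice are arbitrary).

from sklearn.base import clone
from sklearn.model_selection import GridSearchCV

base_est = ANAM(loss="poisson", link="log", n_epochs=30)
est_copy = clone(base_est)            # rebuilds the estimator via get_params() -> __init__

search = GridSearchCV(
    estimator=base_est,
    param_grid={
        "learning_rate": [1e-3, 3e-3],
        "lambda_smooth": [1e-4, 1e-3],
    },
    cv=3,
)
search.fit(X.to_numpy(), np.asarray(y, dtype=np.float32))
print(search.best_params_, search.best_score_)   # best_score_ is negative mean deviance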
class ANAMModel(nn.Module):
    """Actuarial Neural Additive Model.

    Orchestrates one subnetwork per feature plus optional pairwise
    interaction networks. The output is:

        eta = bias + sum_i f_i(x_i) + sum_{(i,j)} g_{ij}(x_i, x_j)
        mu = link_inverse(eta + log_offset)

    Parameters
    ----------
    feature_configs:
        Ordered list of FeatureConfig objects. The column order in X arrays
        passed to forward() must match this list.
    link:
        Link function. 'log' for Poisson/Tweedie/Gamma, 'identity' for
        Gaussian, 'logit' for binary.
    interaction_configs:
        Optional list of InteractionConfig for pairwise terms.
    hidden_sizes:
        Default hidden sizes for all subnetworks (overridden per-feature
        by FeatureConfig.hidden_sizes).
    dropout:
        Dropout rate applied within subnetworks.
    """

    def __init__(
        self,
        feature_configs: List[FeatureConfig],
        link: LinkFunction = "log",
        interaction_configs: Optional[List[InteractionConfig]] = None,
        hidden_sizes: Optional[List[int]] = None,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.feature_configs = feature_configs
        self.link = link
        self.interaction_configs = interaction_configs or []

        # Map feature names to column indices
        self.feature_name_to_idx: Dict[str, int] = {
            cfg.name: i for i, cfg in enumerate(feature_configs)
        }

        default_hidden = hidden_sizes or [64, 32]

        # Build per-feature subnetworks
        feature_nets: Dict[str, nn.Module] = {}
        for cfg in feature_configs:
            net_hidden = cfg.hidden_sizes or default_hidden
            if cfg.feature_type == "continuous":
                feature_nets[cfg.name] = FeatureNetwork(
                    hidden_sizes=net_hidden,
                    monotonicity=cfg.monotonicity,
                    dropout=dropout,
                )
            elif cfg.feature_type == "categorical":
                if cfg.n_categories is None:
                    raise ValueError(
                        f"Feature '{cfg.name}' is categorical but n_categories is not set."
                    )
                feature_nets[cfg.name] = CategoricalFeatureNetwork(
                    n_categories=cfg.n_categories,
                    embedding_dim=cfg.embedding_dim,
                    hidden_sizes=net_hidden,
                    dropout=dropout,
                )
            else:
                raise ValueError(f"Unknown feature_type: {cfg.feature_type!r}")

        self.feature_nets = nn.ModuleDict(feature_nets)

        # Build interaction subnetworks
        interaction_nets: Dict[str, nn.Module] = {}
        for icfg in self.interaction_configs:
            key = f"{icfg.feature_i}_x_{icfg.feature_j}"
            i_idx = self.feature_name_to_idx[icfg.feature_i]
            j_idx = self.feature_name_to_idx[icfg.feature_j]
            i_hidden = icfg.hidden_sizes or [32, 16]
            interaction_nets[key] = InteractionNetwork(
                feature_indices=(i_idx, j_idx),
                hidden_sizes=i_hidden,
                dropout=dropout,
            )

        self.interaction_nets = nn.ModuleDict(interaction_nets)

        # Scalar bias (learnable)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        X: torch.Tensor,
        log_exposure: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass through all subnetworks.

        Parameters
        ----------
        X:
            Feature matrix, shape (batch, n_features). Continuous features
            should be pre-normalised. Categorical features should be integer
            indices (will be cast to long internally).
        log_exposure:
            Log of exposure (e.g. log policy duration in years), shape
            (batch,). Added as offset to the linear predictor before the
            link function. If None, no offset applied (equivalent to
            exposure=1 for all observations).

        Returns
        -------
        torch.Tensor
            Predicted means mu, shape (batch,).
        """
        batch_size = X.shape[0]

        # Accumulate linear predictor starting from bias
        eta = self.bias.expand(batch_size)

        # Feature subnetwork contributions
        for i, cfg in enumerate(self.feature_configs):
            x_i = X[:, i]
            net = self.feature_nets[cfg.name]

            if cfg.feature_type == "categorical":
                contrib = net(x_i.long()).squeeze(-1)
            else:
                contrib = net(x_i.float()).squeeze(-1)

            eta = eta + contrib

        # Interaction subnetwork contributions
        for icfg in self.interaction_configs:
            key = f"{icfg.feature_i}_x_{icfg.feature_j}"
            i_idx = self.feature_name_to_idx[icfg.feature_i]
            j_idx = self.feature_name_to_idx[icfg.feature_j]

            x_i = X[:, i_idx].float()
            x_j = X[:, j_idx].float()

            contrib = self.interaction_nets[key](x_i, x_j).squeeze(-1)
            eta = eta + contrib

        # Exposure offset
        if log_exposure is not None:
            eta = eta + log_exposure

        # Apply link inverse
        return self._link_inverse(eta)

    def _link_inverse(self, eta: torch.Tensor) -> torch.Tensor:
        """Apply inverse link function to linear predictor."""
        if self.link == "log":
            return torch.exp(eta)
        elif self.link == "identity":
            return eta
        elif self.link == "logit":
            return torch.sigmoid(eta)
        else:
            raise ValueError(f"Unknown link function: {self.link!r}")

    def linear_predictor(
        self,
        X: torch.Tensor,
        log_exposure: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return eta (linear predictor) without applying link inverse.

        Useful for inspecting additive contributions before exponentiation.
        """
        batch_size = X.shape[0]
        eta = self.bias.expand(batch_size)

        for i, cfg in enumerate(self.feature_configs):
            x_i = X[:, i]
            net = self.feature_nets[cfg.name]
            if cfg.feature_type == "categorical":
                contrib = net(x_i.long()).squeeze(-1)
            else:
                contrib = net(x_i.float()).squeeze(-1)
            eta = eta + contrib

        for icfg in self.interaction_configs:
            key = f"{icfg.feature_i}_x_{icfg.feature_j}"
            i_idx = self.feature_name_to_idx[icfg.feature_i]
            j_idx = self.feature_name_to_idx[icfg.feature_j]
            contrib = self.interaction_nets[key](
                X[:, i_idx].float(), X[:, j_idx].float()
            ).squeeze(-1)
            eta = eta + contrib

        if log_exposure is not None:
            eta = eta + log_exposure

        return eta

    def feature_contribution(
        self, X: torch.Tensor, feature_name: str
    ) -> torch.Tensor:
        """Return the contribution of a single feature for each observation.

        Useful for explaining individual predictions: the contribution from
        feature i is exactly f_i(x_i), the output of that subnetwork.

        Parameters
        ----------
        X:
            Feature matrix (batch, n_features).
        feature_name:
            Name of the feature to inspect.

        Returns
        -------
        torch.Tensor
            Shape (batch,). Values of f_i(x_i) for each row.
        """
        idx = self.feature_name_to_idx[feature_name]
        cfg = self.feature_configs[idx]
        net = self.feature_nets[feature_name]
        x_i = X[:, idx]

        if cfg.feature_type == "categorical":
            return net(x_i.long()).squeeze(-1)
        else:
            return net(x_i.float()).squeeze(-1)

    def project_monotone_weights(self) -> None:
        """Enforce monotonicity constraints on all relevant subnetworks.

        Call this after optimizer.step() in the training loop.
        Does nothing for non-monotone and categorical features.
        """
        for cfg in self.feature_configs:
            if cfg.feature_type == "continuous" and cfg.monotonicity != "none":
                net = self.feature_nets[cfg.name]
                assert isinstance(net, FeatureNetwork)
                net.project_weights()

    def feature_importance(self) -> Dict[str, float]:
        """Estimate feature importance as the L2 norm of each feature subnetwork's weights.

        Larger norm = larger potential contribution from that feature. This
        is a quick heuristic for feature selection — not a replacement for
        proper permutation importance or SHAP values on the additive model.

        Returns
        -------
        Dict[str, float]
            Feature name -> importance score, sorted descending.
        """
        importances: Dict[str, float] = {}

        for name, net in self.feature_nets.items():
            total_norm = 0.0
            for param in net.parameters():
                total_norm += param.data.norm(2).item() ** 2
            importances[name] = float(total_norm ** 0.5)

        return dict(sorted(importances.items(), key=lambda x: x[1], reverse=True))

    @property
    def n_features(self) -> int:
        """Number of features this model was built for."""
        return len(self.feature_configs)

    @property
    def feature_names(self) -> List[str]:
        """Ordered list of feature names."""
        return [cfg.name for cfg in self.feature_configs]
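For lower-level control, ANAMModel can be built directly from FeatureConfig objects and called as a plain torch module. The sketch below assumes pre-normalised continuous inputs and integer-coded categories; all values and feature names are made up.

import torch

from insurance_gam.anam import ANAMModel, FeatureConfig

configs = [
    FeatureConfig(name="driver_age", feature_type="continuous", monotonicity="decreasing"),
    FeatureConfig(name="region", feature_type="categorical", n_categories=6),
]
net = ANAMModel(feature_configs=configs, link="log", hidden_sizes=[32, 16])

# (batch, n_features); column order must match the config list.
X_t = torch.stack(
    [torch.randn(8), torch.randint(0, 6, (8,)).float()], dim=1
)
log_exposure = torch.log(torch.full((8,), 0.5))   # half a policy-year per row

mu = net(X_t, log_exposure=log_exposure)          # expected counts for the given exposure
eta = net.linear_predictor(X_t)                   # additive predictor without the offset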
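With no interaction terms configured, the linear predictor decomposes exactly into the bias plus the per-feature subnetwork outputs, which feature_contribution() exposes. A small consistency check, reusing the net and X_t from the previous sketch:

with torch.no_grad():
    eta = net.linear_predictor(X_t)
    manual = net.bias + sum(
        net.feature_contribution(X_t, name) for name in net.feature_names
    )
    assert torch.allclose(eta, manual, atol=1e-6)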
@dataclass
class FeatureConfig:
    """Configuration for a single feature.

    Parameters
    ----------
    name:
        Feature name (used in output tables and plots).
    feature_type:
        'continuous' or 'categorical'.
    monotonicity:
        Monotonicity constraint. Only applies to continuous features.
    n_categories:
        Number of categories (required if feature_type='categorical').
    embedding_dim:
        Embedding dimension for categorical features.
    hidden_sizes:
        Hidden layer widths for this feature's subnetwork.
    """

    name: str
    feature_type: Literal["continuous", "categorical"] = "continuous"
    monotonicity: MonotonicityDirection = "none"
    n_categories: Optional[int] = None
    embedding_dim: int = 4
    hidden_sizes: Optional[List[int]] = None
@dataclass
class InteractionConfig:
    """Configuration for a pairwise interaction subnetwork.

    Parameters
    ----------
    feature_i, feature_j:
        Names of the two interacting features. Must be in the feature list.
    hidden_sizes:
        Hidden layer widths for the interaction subnetwork.
    """

    feature_i: str
    feature_j: str
    hidden_sizes: Optional[List[int]] = None
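Per-feature settings on FeatureConfig override the model-level defaults (hidden_sizes), embedding_dim only applies to categorical features, and an InteractionConfig names two features that share a joint subnetwork. A brief illustrative sketch with made-up feature names:

from insurance_gam.anam import ANAMModel, FeatureConfig, InteractionConfig

cfgs = [
    FeatureConfig(name="bonus_malus", monotonicity="increasing", hidden_sizes=[16, 8]),
    FeatureConfig(name="vehicle_group", feature_type="categorical",
                  n_categories=20, embedding_dim=8),
]
inter = [InteractionConfig(feature_i="bonus_malus", feature_j="vehicle_group")]

model_with_interaction = ANAMModel(feature_configs=cfgs, interaction_configs=inter)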
class FeatureNetwork(nn.Module):
    """Single-feature MLP subnetwork.

    Takes a scalar input (one feature value) and outputs a scalar
    contribution f_i(x_i). The contribution is offset by its mean over a
    representative sample so that the bias term in the full model captures
    the population mean prediction.

    Parameters
    ----------
    hidden_sizes:
        Width of each hidden layer. E.g. [64, 32] gives two hidden layers.
    activation:
        'relu' (default) or 'exu'. ReLU is more stable; ExU more expressive.
        When 'exu' is specified, ExUActivation replaces the nn.Linear +
        nn.ReLU pair for each hidden layer (ExU fuses the linear transform
        into the activation).
    monotonicity:
        'increasing', 'decreasing', or 'none'. Enforced by projecting
        weight matrices onto the non-negative (or non-positive) orthant
        after each gradient step via project_weights().
    dropout:
        Dropout rate applied between hidden layers. 0.0 disables dropout.
    """

    def __init__(
        self,
        hidden_sizes: list[int] | None = None,
        activation: Literal["relu", "exu"] = "relu",
        monotonicity: MonotonicityDirection = "none",
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        if hidden_sizes is None:
            hidden_sizes = [64, 32]

        self.hidden_sizes = hidden_sizes
        self.activation_name = activation
        self.monotonicity = monotonicity
        self.dropout_rate = dropout

        layers: list[nn.Module] = []
        in_dim = 1

        for i, out_dim in enumerate(hidden_sizes):
            if activation == "relu":
                # Standard linear + ReLU pair.
                linear = nn.Linear(in_dim, out_dim)
                # Weight initialisation: smaller init stabilises training with
                # many subnetworks summing together.
                nn.init.xavier_uniform_(linear.weight, gain=0.5)
                nn.init.zeros_(linear.bias)
                layers.append(linear)
                layers.append(nn.ReLU())
            elif activation == "exu":
                # P0-3 fix: wire up ExUActivation properly.
                # ExU fuses the linear transform into the activation — no
                # preceding nn.Linear. The ExUActivation module takes
                # (in_features, out_features) and outputs (batch, out_features).
                layers.append(ExUActivation(in_features=in_dim, out_features=out_dim))
            else:
                raise ValueError(f"Unknown activation: {activation!r}")

            if dropout > 0.0 and i < len(hidden_sizes) - 1:
                layers.append(nn.Dropout(p=dropout))

            in_dim = out_dim

        # Output layer: single scalar
        output_layer = nn.Linear(in_dim, 1)
        nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
        nn.init.zeros_(output_layer.bias)
        layers.append(output_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x:
            Shape (batch,) or (batch, 1). Single feature values.

        Returns
        -------
        torch.Tensor
            Shape (batch, 1). Feature contribution f_i(x_i).
        """
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        return self.network(x)

    def project_weights(self) -> None:
        """Enforce monotonicity by clamping weight matrices in-place.

        For a ReLU network, all-non-negative weights guarantee a
        non-decreasing function; the clamp is the Euclidean projection onto
        the non-negative orthant. For decreasing, clamp to non-positive.

        Call this after every optimizer.step() during training.
        """
        if self.monotonicity == "none":
            return

        for module in self.network.modules():
            if isinstance(module, nn.Linear):
                if self.monotonicity == "increasing":
                    module.weight.data.clamp_(min=0.0)
                elif self.monotonicity == "decreasing":
                    module.weight.data.clamp_(max=0.0)
            elif isinstance(module, ExUActivation):
                if self.monotonicity == "increasing":
                    module.weights.data.clamp_(min=0.0)
                elif self.monotonicity == "decreasing":
                    module.weights.data.clamp_(max=0.0)

    def feature_range(
        self, x_min: float, x_max: float, n_points: int = 200
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Evaluate shape function over a grid.

        Returns (x_grid, f_values) for plotting shape curves.
        """
        self.eval()
        with torch.no_grad():
            x_grid = torch.linspace(x_min, x_max, n_points)
            f_values = self.forward(x_grid).squeeze(-1)
        return x_grid, f_values
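The intended training pattern for a monotone subnetwork is to call project_weights() after every optimizer step. A self-contained sketch on a toy target (data and hyperparameters are illustrative):

import torch

from insurance_gam.anam import FeatureNetwork

fnet = FeatureNetwork(hidden_sizes=[32, 16], monotonicity="increasing")
opt = torch.optim.Adam(fnet.parameters(), lr=1e-2)

x = torch.linspace(-2, 2, 256)
target = torch.relu(x)                       # a non-decreasing target shape

for _ in range(200):
    opt.zero_grad()
    pred = fnet(x).squeeze(-1)
    loss = torch.mean((pred - target) ** 2)
    loss.backward()
    opt.step()
    fnet.project_weights()                   # enforce the monotonicity constraint

# The fitted shape is non-decreasing on any grid, regardless of fit quality.
grid, values = fnet.feature_range(-2.0, 2.0, n_points=50)
assert torch.all(values[1:] >= values[:-1] - 1e-6)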
Single-feature MLP subnetwork.
Takes a scalar input (one feature value) and outputs a scalar contribution f_i(x_i). The contribution is offset by its mean over a representative sample so that the bias term in the full model captures the population mean prediction.
Parameters
hidden_sizes: Width of each hidden layer. E.g. [64, 32] gives two hidden layers. activation: 'relu' (default) or 'exu'. ReLU is more stable; ExU more expressive. When 'exu' is specified, ExUActivation replaces the nn.Linear + nn.ReLU pair for each hidden layer (ExU fuses the linear transform into the activation). monotonicity: 'increasing', 'decreasing', or 'none'. Enforced by projecting weight matrices onto the non-negative (or non-positive) orthant after each gradient step via project_weights(). dropout: Dropout rate applied between hidden layers. 0.0 disables dropout.
91 def __init__( 92 self, 93 hidden_sizes: list[int] | None = None, 94 activation: Literal["relu", "exu"] = "relu", 95 monotonicity: MonotonicityDirection = "none", 96 dropout: float = 0.0, 97 ) -> None: 98 super().__init__() 99 100 if hidden_sizes is None: 101 hidden_sizes = [64, 32] 102 103 self.hidden_sizes = hidden_sizes 104 self.activation_name = activation 105 self.monotonicity = monotonicity 106 self.dropout_rate = dropout 107 108 layers: list[nn.Module] = [] 109 in_dim = 1 110 111 for i, out_dim in enumerate(hidden_sizes): 112 if activation == "relu": 113 # Standard linear + ReLU pair. 114 linear = nn.Linear(in_dim, out_dim) 115 # Weight initialisation: smaller init stabilises training with 116 # many subnetworks summing together. 117 nn.init.xavier_uniform_(linear.weight, gain=0.5) 118 nn.init.zeros_(linear.bias) 119 layers.append(linear) 120 layers.append(nn.ReLU()) 121 elif activation == "exu": 122 # P0-3 fix: wire up ExUActivation properly. 123 # ExU fuses the linear transform into the activation — no 124 # preceding nn.Linear. The ExUActivation module takes 125 # (in_features, out_features) and outputs (batch, out_features). 126 layers.append(ExUActivation(in_features=in_dim, out_features=out_dim)) 127 else: 128 raise ValueError(f"Unknown activation: {activation!r}") 129 130 if dropout > 0.0 and i < len(hidden_sizes) - 1: 131 layers.append(nn.Dropout(p=dropout)) 132 133 in_dim = out_dim 134 135 # Output layer: single scalar 136 output_layer = nn.Linear(in_dim, 1) 137 nn.init.xavier_uniform_(output_layer.weight, gain=0.1) 138 nn.init.zeros_(output_layer.bias) 139 layers.append(output_layer) 140 141 self.network = nn.Sequential(*layers)
Forward pass.
Parameters
x: Shape (batch,) or (batch, 1). Single feature values.
Returns
torch.Tensor Shape (batch, 1). Feature contribution f_i(x_i).
Enforce monotonicity by clamping weight matrices in-place.
For a ReLU network, all-non-negative weights guarantee a non-decreasing function; the clamp is a Euclidean projection onto the non-negative orthant. For a decreasing constraint, weights are clamped to be non-positive.
Call this after every optimizer.step() during training.
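A sketch of the clamp-after-step pattern on a toy monotone target (names and hyperparameters are illustrative; FeatureNetwork and torch as above)::

    net = FeatureNetwork(monotonicity="increasing")
    opt = torch.optim.Adam(net.parameters(), lr=1e-2)

    x = torch.linspace(-2.0, 2.0, 256)
    y = x.clamp(min=0.0)                       # toy non-decreasing target

    for _ in range(200):
        opt.zero_grad()
        loss = ((net(x).squeeze(-1) - y) ** 2).mean()
        loss.backward()
        opt.step()
        net.project_weights()                  # re-impose the constraint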
Evaluate shape function over a grid.
Returns (x_grid, f_values) for plotting shape curves.
198class CategoricalFeatureNetwork(nn.Module): 199 """Subnetwork for categorical features using an embedding layer. 200 201 Each category maps to a learned embedding (dim: embedding_dim), which 202 then passes through a small MLP. This allows the model to discover 203 structure in category space (e.g. similar vehicle groups cluster 204 together) without requiring manual one-hot encoding. 205 206 For regulatory documentation, the output for each category level can be 207 extracted as a relativity table — exactly like a GLM factor table. 208 """ 209 210 def __init__( 211 self, 212 n_categories: int, 213 embedding_dim: int = 4, 214 hidden_sizes: list[int] | None = None, 215 dropout: float = 0.0, 216 ) -> None: 217 super().__init__() 218 219 if hidden_sizes is None: 220 hidden_sizes = [32] 221 222 self.n_categories = n_categories 223 self.embedding_dim = embedding_dim 224 225 self.embedding = nn.Embedding(n_categories, embedding_dim) 226 nn.init.normal_(self.embedding.weight, mean=0.0, std=0.1) 227 228 layers: list[nn.Module] = [] 229 in_dim = embedding_dim 230 231 for i, out_dim in enumerate(hidden_sizes): 232 linear = nn.Linear(in_dim, out_dim) 233 nn.init.xavier_uniform_(linear.weight, gain=0.5) 234 nn.init.zeros_(linear.bias) 235 layers.append(linear) 236 layers.append(nn.ReLU()) 237 if dropout > 0.0 and i < len(hidden_sizes) - 1: 238 layers.append(nn.Dropout(p=dropout)) 239 in_dim = out_dim 240 241 output_layer = nn.Linear(in_dim, 1) 242 nn.init.xavier_uniform_(output_layer.weight, gain=0.1) 243 nn.init.zeros_(output_layer.bias) 244 layers.append(output_layer) 245 246 self.network = nn.Sequential(*layers) 247 248 def forward(self, x: torch.Tensor) -> torch.Tensor: 249 """Forward pass. 250 251 Parameters 252 ---------- 253 x: 254 Shape (batch,). Integer category indices. 255 256 Returns 257 ------- 258 torch.Tensor 259 Shape (batch, 1). Category contribution. 260 """ 261 embedded = self.embedding(x.long()) # (batch, embedding_dim) 262 return self.network(embedded) 263 264 def category_table(self) -> dict[int, float]: 265 """Extract per-category contributions as a relativity table. 266 267 Returns a dict mapping category index to scalar contribution value. 268 Useful for regulatory documentation and GLM comparison. 269 """ 270 self.eval() 271 with torch.no_grad(): 272 indices = torch.arange(self.n_categories) 273 contribs = self.forward(indices).squeeze(-1) 274 return {int(i): float(v) for i, v in enumerate(contribs)}
Subnetwork for categorical features using an embedding layer.
Each category maps to a learned embedding (dim: embedding_dim), which then passes through a small MLP. This allows the model to discover structure in category space (e.g. similar vehicle groups cluster together) without requiring manual one-hot encoding.
For regulatory documentation, the output for each category level can be extracted as a relativity table — exactly like a GLM factor table.
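A usage sketch (the number of levels and the indices are illustrative)::

    import torch
    from insurance_gam.anam import CategoricalFeatureNetwork

    veh_net = CategoricalFeatureNetwork(n_categories=12, embedding_dim=4)

    idx = torch.tensor([0, 3, 3, 11])          # integer category indices
    contrib = veh_net(idx)                     # shape (4, 1)
    table = veh_net.category_table()           # {0: f_0, 1: f_1, ..., 11: f_11}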
Forward pass.
Parameters
x: Shape (batch,). Integer category indices.
Returns
torch.Tensor Shape (batch, 1). Category contribution.
Extract per-category contributions as a relativity table.
Returns a dict mapping category index to scalar contribution value. Useful for regulatory documentation and GLM comparison.
26class InteractionNetwork(nn.Module): 27 """Pairwise interaction subnetwork g_{ij}(x_i, x_j). 28 29 Takes two feature values and learns their joint effect. The network 30 is intentionally shallow to avoid overfitting and to keep the 31 interaction contribution interpretable as a 2D surface. 32 33 Parameters 34 ---------- 35 feature_indices: 36 Tuple (i, j) identifying which features this network handles. 37 Used for bookkeeping and shape function export. 38 hidden_sizes: 39 Width of each hidden layer. Shallower than single-feature nets 40 is recommended — interactions should be simple corrections to 41 the additive baseline. 42 dropout: 43 Dropout rate between hidden layers. 44 """ 45 46 def __init__( 47 self, 48 feature_indices: tuple[int, int], 49 hidden_sizes: list[int] | None = None, 50 dropout: float = 0.0, 51 ) -> None: 52 super().__init__() 53 54 if hidden_sizes is None: 55 hidden_sizes = [32, 16] 56 57 self.feature_indices = feature_indices 58 self.hidden_sizes = hidden_sizes 59 60 layers: list[nn.Module] = [] 61 in_dim = 2 # two features concatenated 62 63 for i, out_dim in enumerate(hidden_sizes): 64 linear = nn.Linear(in_dim, out_dim) 65 nn.init.xavier_uniform_(linear.weight, gain=0.5) 66 nn.init.zeros_(linear.bias) 67 layers.append(linear) 68 layers.append(nn.ReLU()) 69 if dropout > 0.0 and i < len(hidden_sizes) - 1: 70 layers.append(nn.Dropout(p=dropout)) 71 in_dim = out_dim 72 73 output_layer = nn.Linear(in_dim, 1) 74 nn.init.xavier_uniform_(output_layer.weight, gain=0.1) 75 nn.init.zeros_(output_layer.bias) 76 layers.append(output_layer) 77 78 self.network = nn.Sequential(*layers) 79 80 def forward(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor: 81 """Forward pass. 82 83 Parameters 84 ---------- 85 x_i, x_j: 86 Shape (batch,) or (batch, 1). Values for the two interacting 87 features. 88 89 Returns 90 ------- 91 torch.Tensor 92 Shape (batch, 1). Interaction contribution g_{ij}(x_i, x_j). 93 """ 94 if x_i.dim() == 1: 95 x_i = x_i.unsqueeze(-1) 96 if x_j.dim() == 1: 97 x_j = x_j.unsqueeze(-1) 98 combined = torch.cat([x_i, x_j], dim=-1) # (batch, 2) 99 return self.network(combined) 100 101 def interaction_grid( 102 self, 103 xi_min: float, 104 xi_max: float, 105 xj_min: float, 106 xj_max: float, 107 n_points: int = 50, 108 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 109 """Evaluate interaction surface over a 2D grid. 110 111 Returns (xi_grid, xj_grid, g_values) where g_values is an 112 (n_points, n_points) tensor representing the interaction surface. 113 Useful for heatmap visualisation. 114 """ 115 self.eval() 116 with torch.no_grad(): 117 xi = torch.linspace(xi_min, xi_max, n_points) 118 xj = torch.linspace(xj_min, xj_max, n_points) 119 xi_grid, xj_grid = torch.meshgrid(xi, xj, indexing="ij") 120 xi_flat = xi_grid.reshape(-1) 121 xj_flat = xj_grid.reshape(-1) 122 g_flat = self.forward(xi_flat, xj_flat).squeeze(-1) 123 g_values = g_flat.reshape(n_points, n_points) 124 return xi_grid, xj_grid, g_values
Pairwise interaction subnetwork g_{ij}(x_i, x_j).
Takes two feature values and learns their joint effect. The network is intentionally shallow to avoid overfitting and to keep the interaction contribution interpretable as a 2D surface.
Parameters
feature_indices:
    Tuple (i, j) identifying which features this network handles. Used for
    bookkeeping and shape function export.
hidden_sizes:
    Width of each hidden layer. Shallower than single-feature nets is
    recommended — interactions should be simple corrections to the additive
    baseline.
dropout:
    Dropout rate between hidden layers.
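A construction sketch (the feature indices and names are illustrative)::

    import torch
    from insurance_gam.anam import InteractionNetwork

    # Joint effect of columns 0 and 3 of X, e.g. (driver_age, vehicle_power).
    g = InteractionNetwork(feature_indices=(0, 3), hidden_sizes=[32, 16])

    x_i = torch.randn(256)
    x_j = torch.randn(256)
    out = g(x_i, x_j)                          # shape (256, 1)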
Forward pass.
Parameters
x_i, x_j: Shape (batch,) or (batch, 1). Values for the two interacting features.
Returns
torch.Tensor Shape (batch, 1). Interaction contribution g_{ij}(x_i, x_j).
Evaluate interaction surface over a 2D grid.
Returns (xi_grid, xj_grid, g_values) where g_values is an (n_points, n_points) tensor representing the interaction surface. Useful for heatmap visualisation.
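Continuing the sketch above, the surface can be drawn as a filled contour plot (matplotlib assumed available)::

    import matplotlib.pyplot as plt

    xi_grid, xj_grid, g_values = g.interaction_grid(-2.0, 2.0, -2.0, 2.0, n_points=50)
    plt.contourf(xi_grid.numpy(), xj_grid.numpy(), g_values.numpy(), levels=20)
    plt.colorbar(label="interaction contribution")
    plt.xlabel("x_i")
    plt.ylabel("x_j")
    plt.show()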
113class ANAMTrainer: 114 """Manages the training loop for ANAMModel. 115 116 Parameters 117 ---------- 118 model: 119 The ANAMModel to train. Modified in-place. 120 config: 121 Training hyperparameters. 122 """ 123 124 def __init__(self, model: ANAMModel, config: TrainingConfig) -> None: 125 self.model = model 126 self.config = config 127 128 if config.device is None: 129 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 130 else: 131 self.device = torch.device(config.device) 132 133 self.model.to(self.device) 134 self.history = TrainingHistory() 135 136 def fit( 137 self, 138 X: np.ndarray, 139 y: np.ndarray, 140 exposure: Optional[np.ndarray] = None, 141 ) -> TrainingHistory: 142 """Train the model. 143 144 Parameters 145 ---------- 146 X: 147 Feature matrix, shape (n, n_features). Continuous features 148 should be normalised before calling fit(). 149 y: 150 Target vector, shape (n,). Claim counts, rates, or severities. 151 exposure: 152 Exposure vector, shape (n,). Policy years or similar. If None, 153 uniform exposure (all 1.0) is assumed. 154 155 Returns 156 ------- 157 TrainingHistory 158 Training and validation loss per epoch. 159 """ 160 n = len(y) 161 162 if exposure is None: 163 exposure = np.ones(n, dtype=np.float32) 164 165 # Convert to tensors 166 X_t = torch.tensor(X, dtype=torch.float32) 167 y_t = torch.tensor(y, dtype=torch.float32) 168 exp_t = torch.tensor(exposure, dtype=torch.float32) 169 170 # Train/val split 171 n_val = max(1, int(n * self.config.val_fraction)) 172 perm = torch.randperm(n) 173 val_idx = perm[:n_val] 174 train_idx = perm[n_val:] 175 176 X_train, y_train, exp_train = X_t[train_idx], y_t[train_idx], exp_t[train_idx] 177 X_val, y_val, exp_val = X_t[val_idx], y_t[val_idx], exp_t[val_idx] 178 179 # Precompute feature ranges for smoothness penalty 180 feature_ranges = self._compute_feature_ranges(X_train) 181 182 # DataLoader 183 train_dataset = TensorDataset(X_train, y_train, exp_train) 184 train_loader = DataLoader( 185 train_dataset, 186 batch_size=self.config.batch_size, 187 shuffle=True, 188 drop_last=False, 189 ) 190 191 # Optimiser and LR scheduler 192 optimizer = torch.optim.Adam( 193 self.model.parameters(), 194 lr=self.config.learning_rate, 195 weight_decay=0.0, # L2 handled manually for flexibility 196 ) 197 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 198 optimizer, T_max=self.config.n_epochs, eta_min=self.config.learning_rate * 0.01 199 ) 200 201 best_val_loss = float("inf") 202 best_state = copy.deepcopy(self.model.state_dict()) 203 patience_counter = 0 204 205 for epoch in range(self.config.n_epochs): 206 t0 = time.time() 207 train_loss = self._train_epoch( 208 train_loader, optimizer, feature_ranges 209 ) 210 val_loss = self._evaluate(X_val, y_val, exp_val) 211 212 scheduler.step() 213 214 self.history.train_loss.append(train_loss) 215 self.history.val_loss.append(val_loss) 216 self.history.epoch_times.append(time.time() - t0) 217 218 if self.config.verbose > 0 and (epoch + 1) % self.config.verbose == 0: 219 print( 220 f"Epoch {epoch + 1:4d}/{self.config.n_epochs} | " 221 f"train={train_loss:.6f} | val={val_loss:.6f}" 222 ) 223 224 # Early stopping 225 if val_loss < best_val_loss - self.config.min_delta: 226 best_val_loss = val_loss 227 best_state = copy.deepcopy(self.model.state_dict()) 228 self.history.best_epoch = epoch 229 patience_counter = 0 230 else: 231 patience_counter += 1 232 if patience_counter >= self.config.patience: 233 if self.config.verbose > 0: 234 print(f"Early stopping at epoch {epoch + 
1}.") 235 self.history.stopped_early = True 236 break 237 238 # Restore best weights 239 self.model.load_state_dict(best_state) 240 return self.history 241 242 def _train_epoch( 243 self, 244 loader: DataLoader, 245 optimizer: torch.optim.Optimizer, 246 feature_ranges: Dict[str, Tuple[float, float]], 247 ) -> float: 248 """One pass through the training data.""" 249 self.model.train() 250 total_loss = 0.0 251 total_weight = 0.0 252 253 for X_batch, y_batch, exp_batch in loader: 254 X_batch = X_batch.to(self.device) 255 y_batch = y_batch.to(self.device) 256 exp_batch = exp_batch.to(self.device) 257 258 optimizer.zero_grad() 259 260 # Log exposure offset for log-link models 261 log_exp = torch.log(exp_batch.clamp(min=1e-8)) if self.model.link == "log" else None 262 263 # Forward pass 264 y_pred = self.model(X_batch, log_exposure=log_exp) 265 266 # Distributional loss 267 loss = self._distributional_loss(y_pred, y_batch, exp_batch) 268 269 # Smoothness penalty over all continuous features 270 if self.config.lambda_smooth > 0.0: 271 for cfg in self.model.feature_configs: 272 if cfg.feature_type == "continuous": 273 x_min, x_max = feature_ranges[cfg.name] 274 net = self.model.feature_nets[cfg.name] 275 loss = loss + smoothness_penalty( 276 net, x_min, x_max, 277 n_points=self.config.smooth_n_points, 278 lambda_smooth=self.config.lambda_smooth, 279 ) 280 281 # L2 ridge 282 if self.config.lambda_l2 > 0.0: 283 all_nets = list(self.model.feature_nets.values()) + list( 284 self.model.interaction_nets.values() 285 ) 286 loss = loss + l2_ridge_penalty(all_nets, self.config.lambda_l2) 287 288 # L1 sparsity 289 if self.config.lambda_l1 > 0.0: 290 all_nets = list(self.model.feature_nets.values()) 291 loss = loss + l1_sparsity_penalty(all_nets, self.config.lambda_l1) 292 293 loss.backward() 294 optimizer.step() 295 296 # Monotonicity projection (Dykstra step) 297 self.model.project_monotone_weights() 298 299 batch_weight = exp_batch.sum().item() 300 total_loss += loss.item() * batch_weight 301 total_weight += batch_weight 302 303 return total_loss / max(total_weight, 1e-8) 304 305 def _evaluate( 306 self, 307 X: torch.Tensor, 308 y: torch.Tensor, 309 exposure: torch.Tensor, 310 ) -> float: 311 """Evaluate distributional loss on a dataset (no regularisation).""" 312 self.model.eval() 313 with torch.no_grad(): 314 X = X.to(self.device) 315 y = y.to(self.device) 316 exposure = exposure.to(self.device) 317 318 log_exp = torch.log(exposure.clamp(min=1e-8)) if self.model.link == "log" else None 319 y_pred = self.model(X, log_exposure=log_exp) 320 loss = self._distributional_loss(y_pred, y, exposure) 321 322 return float(loss.item()) 323 324 def _distributional_loss( 325 self, 326 y_pred: torch.Tensor, 327 y_true: torch.Tensor, 328 weights: torch.Tensor, 329 ) -> torch.Tensor: 330 """Compute weighted distributional loss.""" 331 cfg = self.config 332 if cfg.loss == "poisson": 333 return poisson_deviance(y_pred, y_true, weights) 334 elif cfg.loss == "tweedie": 335 return tweedie_deviance(y_pred, y_true, p=cfg.tweedie_p, weights=weights) 336 elif cfg.loss == "gamma": 337 return gamma_deviance(y_pred, y_true, weights) 338 elif cfg.loss == "mse": 339 err = (y_pred - y_true) ** 2 340 return (weights * err).sum() / weights.sum().clamp(min=1e-8) 341 else: 342 raise ValueError(f"Unknown loss: {cfg.loss!r}") 343 344 def _compute_feature_ranges( 345 self, X_train: torch.Tensor 346 ) -> Dict[str, Tuple[float, float]]: 347 """Compute per-feature min/max ranges from training data.""" 348 ranges: Dict[str, Tuple[float, float]] 
= {} 349 for i, cfg in enumerate(self.model.feature_configs): 350 if cfg.feature_type == "continuous": 351 col = X_train[:, i] 352 ranges[cfg.name] = (float(col.min().item()), float(col.max().item())) 353 return ranges
Manages the training loop for ANAMModel.
Parameters
model:
    The ANAMModel to train. Modified in-place.
config:
    Training hyperparameters.
Train the model.
Parameters
X:
    Feature matrix, shape (n, n_features). Continuous features should be
    normalised before calling fit().
y:
    Target vector, shape (n,). Claim counts, rates, or severities.
exposure:
    Exposure vector, shape (n,). Policy years or similar. If None, uniform
    exposure (all 1.0) is assumed.
Returns
TrainingHistory Training and validation loss per epoch.
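A training sketch; model stands for an already-constructed ANAMModel (its constructor lives in the model module and is not reproduced here), and X_scaled, y_counts and policy_years are placeholder arrays::

    config = TrainingConfig(loss="poisson", n_epochs=200, batch_size=1024,
                            learning_rate=1e-3, patience=15, verbose=10)
    trainer = ANAMTrainer(model, config)
    history = trainer.fit(X_scaled, y_counts, exposure=policy_years)

    print(history.best_epoch, history.val_loss[history.best_epoch])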
48@dataclass 49class TrainingConfig: 50 """Hyperparameters for the ANAM training loop. 51 52 Parameters 53 ---------- 54 loss: 55 Distributional loss type. 'poisson' for frequency, 'tweedie' for 56 pure premium, 'gamma' for severity, 'mse' for Gaussian. 57 tweedie_p: 58 Tweedie power parameter. Only used when loss='tweedie'. 59 Typical range (1.0, 2.0). Common choices: 1.5 (compound Poisson). 60 n_epochs: 61 Maximum training epochs. 62 batch_size: 63 Mini-batch size. 64 learning_rate: 65 Initial Adam learning rate. 66 lambda_smooth: 67 Smoothness penalty weight. 0.0 disables smoothness regularisation. 68 lambda_l1: 69 L1 sparsity penalty weight. 0.0 disables. 70 lambda_l2: 71 L2 ridge penalty weight (weight decay). 72 smooth_n_points: 73 Number of grid points for smoothness penalty evaluation. 74 val_fraction: 75 Fraction of training data held out for early stopping. 76 patience: 77 Number of epochs without improvement before stopping. 78 min_delta: 79 Minimum improvement in validation loss to count as progress. 80 verbose: 81 Print training progress every `verbose` epochs. 0 = silent. 82 device: 83 PyTorch device string. 'cpu' or 'cuda'. Auto-detected if None. 84 """ 85 86 loss: LossType = "poisson" 87 tweedie_p: float = 1.5 88 n_epochs: int = 100 89 batch_size: int = 512 90 learning_rate: float = 1e-3 91 lambda_smooth: float = 1e-4 92 lambda_l1: float = 0.0 93 lambda_l2: float = 1e-4 94 smooth_n_points: int = 100 95 val_fraction: float = 0.1 96 patience: int = 10 97 min_delta: float = 1e-6 98 verbose: int = 10 99 device: Optional[str] = None
Hyperparameters for the ANAM training loop.
Parameters
loss:
Distributional loss type. 'poisson' for frequency, 'tweedie' for
pure premium, 'gamma' for severity, 'mse' for Gaussian.
tweedie_p:
Tweedie power parameter. Only used when loss='tweedie'.
Typical range (1.0, 2.0). Common choices: 1.5 (compound Poisson).
n_epochs:
Maximum training epochs.
batch_size:
Mini-batch size.
learning_rate:
Initial Adam learning rate.
lambda_smooth:
Smoothness penalty weight. 0.0 disables smoothness regularisation.
lambda_l1:
L1 sparsity penalty weight. 0.0 disables.
lambda_l2:
L2 ridge penalty weight (weight decay).
smooth_n_points:
Number of grid points for smoothness penalty evaluation.
val_fraction:
Fraction of training data held out for early stopping.
patience:
Number of epochs without improvement before stopping.
min_delta:
Minimum improvement in validation loss to count as progress.
verbose:
Print training progress every verbose epochs. 0 = silent.
device:
PyTorch device string. 'cpu' or 'cuda'. Auto-detected if None.
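A pure-premium configuration sketch (values are illustrative, not recommendations)::

    pure_premium_cfg = TrainingConfig(
        loss="tweedie",
        tweedie_p=1.5,
        lambda_smooth=1e-3,    # slightly stronger smoothing than the default
        n_epochs=150,
        val_fraction=0.15,
    )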
102@dataclass 103class TrainingHistory: 104 """Records per-epoch training and validation losses.""" 105 106 train_loss: List[float] = field(default_factory=list) 107 val_loss: List[float] = field(default_factory=list) 108 epoch_times: List[float] = field(default_factory=list) 109 best_epoch: int = 0 110 stopped_early: bool = False
Records per-epoch training and validation losses.
36def poisson_deviance( 37 y_pred: torch.Tensor, 38 y_true: torch.Tensor, 39 weights: torch.Tensor | None = None, 40 eps: float = 1e-8, 41) -> torch.Tensor: 42 """Poisson deviance loss (mean over batch). 43 44 Deviance = 2 * w * [y * log(y / mu) - (y - mu)] 45 where mu = y_pred, y = y_true, w = observation weight. 46 47 Parameters 48 ---------- 49 y_pred: 50 Predicted mean (mu), strictly positive. Shape (batch,). 51 y_true: 52 Observed values. Shape (batch,). 53 weights: 54 Observation weights (e.g. exposure). Shape (batch,). If None, 55 uses uniform weights. 56 eps: 57 Numerical floor for y_true in log computation. 58 59 Returns 60 ------- 61 torch.Tensor 62 Scalar mean deviance. 63 """ 64 mu = y_pred.clamp(min=eps) 65 y = y_true.clamp(min=eps) 66 67 # y * log(y/mu) term: undefined at y=0, use 0 * log(0) = 0 convention 68 log_term = torch.where( 69 y_true > eps, 70 y * (torch.log(y) - torch.log(mu)), 71 torch.zeros_like(y), 72 ) 73 d = 2.0 * (log_term - (y_true - mu)) 74 75 if weights is not None: 76 return (weights * d).sum() / weights.sum().clamp(min=eps) 77 return d.mean()
Poisson deviance loss (mean over batch).
Deviance = 2 * w * [y * log(y / mu) - (y - mu)] where mu = y_pred, y = y_true, w = observation weight.
Parameters
y_pred:
    Predicted mean (mu), strictly positive. Shape (batch,).
y_true:
    Observed values. Shape (batch,).
weights:
    Observation weights (e.g. exposure). Shape (batch,). If None, uses
    uniform weights.
eps:
    Numerical floor for y_true in log computation.
Returns
torch.Tensor Scalar mean deviance.
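A toy evaluation sketch (values are illustrative)::

    import torch

    mu = torch.tensor([0.08, 0.12, 0.30])     # predicted claim frequencies
    y = torch.tensor([0.0, 0.0, 1.0])         # observed claim counts
    w = torch.tensor([1.0, 0.5, 1.0])         # exposure in policy-years

    d = poisson_deviance(mu, y, weights=w)    # exposure-weighted mean deviance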
80def gamma_deviance( 81 y_pred: torch.Tensor, 82 y_true: torch.Tensor, 83 weights: torch.Tensor | None = None, 84 eps: float = 1e-8, 85) -> torch.Tensor: 86 """Gamma deviance loss (mean over batch). 87 88 Deviance = 2 * w * [log(mu/y) + (y - mu)/mu] 89 where mu = y_pred, y = y_true. 90 91 Used for claim severity (positive, right-skewed). 92 """ 93 mu = y_pred.clamp(min=eps) 94 y = y_true.clamp(min=eps) 95 96 d = 2.0 * (torch.log(mu / y) + (y - mu) / mu) 97 98 if weights is not None: 99 return (weights * d).sum() / weights.sum().clamp(min=eps) 100 return d.mean()
Gamma deviance loss (mean over batch).
Deviance = 2 * w * [log(mu/y) + (y - mu)/mu] where mu = y_pred, y = y_true.
Used for claim severity (positive, right-skewed).
103def tweedie_deviance( 104 y_pred: torch.Tensor, 105 y_true: torch.Tensor, 106 p: float = 1.5, 107 weights: torch.Tensor | None = None, 108 eps: float = 1e-8, 109) -> torch.Tensor: 110 """Tweedie deviance loss (mean over batch). 111 112 For p in (1, 2): 113 D(y, mu) = 2 * [y^(2-p)/((1-p)*(2-p)) - y*mu^(1-p)/(1-p) + mu^(2-p)/(2-p)] 114 115 Special cases: 116 - p=1: Poisson (use poisson_deviance for numerical stability) 117 - p=2: Gamma (use gamma_deviance) 118 - p=1.5: Inverse Gaussian-like, common for pure premium 119 120 Parameters 121 ---------- 122 p: 123 Tweedie power parameter. Must not equal 1 or 2. Typical range 124 (1.0, 2.0) for compound Poisson-Gamma. 125 """ 126 if abs(p - 1.0) < 1e-6: 127 return poisson_deviance(y_pred, y_true, weights, eps) 128 if abs(p - 2.0) < 1e-6: 129 return gamma_deviance(y_pred, y_true, weights, eps) 130 131 mu = y_pred.clamp(min=eps) 132 y = y_true.clamp(min=eps) 133 134 term1 = y.pow(2 - p) / ((1 - p) * (2 - p)) 135 term2 = y * mu.pow(1 - p) / (1 - p) 136 term3 = mu.pow(2 - p) / (2 - p) 137 138 # Handle y=0 in term1: 0^(2-p) for p<2 is 0, so term1 -> 0 139 term1 = torch.where(y_true < eps, torch.zeros_like(term1), term1) 140 141 d = 2.0 * (term1 - term2 + term3) 142 143 if weights is not None: 144 return (weights * d).sum() / weights.sum().clamp(min=eps) 145 return d.mean()
Tweedie deviance loss (mean over batch).
For p in (1, 2):
    D(y, mu) = 2 * [y^(2-p) / ((1-p)*(2-p)) - y*mu^(1-p)/(1-p) + mu^(2-p)/(2-p)]
Special cases:
- p=1: Poisson (use poisson_deviance for numerical stability)
- p=2: Gamma (use gamma_deviance)
- p=1.5: compound Poisson-Gamma, common for pure premium
Parameters
p:
    Tweedie power parameter. Must not equal 1 or 2. Typical range (1.0, 2.0)
    for compound Poisson-Gamma.
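Reusing the toy tensors from the Poisson sketch above::

    d15 = tweedie_deviance(mu, y, p=1.5, weights=w)   # compound Poisson-Gamma
    d10 = tweedie_deviance(mu, y, p=1.0, weights=w)   # dispatches to poisson_deviance
    d20 = tweedie_deviance(mu, y, p=2.0, weights=w)   # dispatches to gamma_deviance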
148def bernoulli_deviance( 149 y_pred_logit: torch.Tensor, 150 y_true: torch.Tensor, 151 weights: torch.Tensor | None = None, 152 eps: float = 1e-8, 153) -> torch.Tensor: 154 """Binary cross-entropy deviance (logit inputs). 155 156 For binary outcomes (lapse, catastrophic event indicators). 157 y_pred_logit is the raw network output (before sigmoid). 158 """ 159 d = F.binary_cross_entropy_with_logits(y_pred_logit, y_true, reduction="none") 160 if weights is not None: 161 return (weights * d).sum() / weights.sum().clamp(min=eps) 162 return d.mean()
Binary cross-entropy deviance (logit inputs).
For binary outcomes (lapse, catastrophic event indicators). y_pred_logit is the raw network output (before sigmoid).
170def smoothness_penalty( 171 feature_network: "torch.nn.Module", 172 x_min: float, 173 x_max: float, 174 n_points: int = 100, 175 lambda_smooth: float = 1e-4, 176) -> torch.Tensor: 177 """Second-order difference penalty on a feature network's shape. 178 179 Evaluates f_i at n_points evenly spaced over [x_min, x_max] and 180 penalises second differences: sum((f_{k+2} - 2*f_{k+1} + f_k)^2). 181 182 This discourages shape functions that change direction rapidly. 183 Lambda_smooth controls the trade-off between fit and smoothness. 184 """ 185 # Determine device from network parameters so the grid lands on the same 186 # device as the model (GPU or CPU). torch.linspace defaults to CPU which 187 # would cause device-mismatch errors when the network is on CUDA. 188 try: 189 device = next(feature_network.parameters()).device 190 except StopIteration: 191 device = torch.device("cpu") 192 x_grid = torch.linspace(x_min, x_max, n_points, device=device) 193 f_vals = feature_network(x_grid).squeeze(-1) # (n_points,) 194 195 # Second-order differences: f[k+2] - 2*f[k+1] + f[k] 196 second_diff = f_vals[2:] - 2 * f_vals[1:-1] + f_vals[:-2] 197 penalty = lambda_smooth * (second_diff ** 2).sum() 198 return penalty
Second-order difference penalty on a feature network's shape.
Evaluates f_i at n_points evenly spaced over [x_min, x_max] and penalises second differences: sum((f_{k+2} - 2*f_{k+1} + f_k)^2).
This discourages shape functions that change direction rapidly. Lambda_smooth controls the trade-off between fit and smoothness.
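A sketch of how the penalty is added to a batch loss (y_pred, y_batch, exp_batch and net are placeholders; x_min/x_max would come from the training data)::

    loss = poisson_deviance(y_pred, y_batch, weights=exp_batch)
    loss = loss + smoothness_penalty(net, x_min=-2.5, x_max=2.5,
                                     n_points=100, lambda_smooth=1e-4)
    loss.backward()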
201def l1_sparsity_penalty( 202 feature_networks: list["torch.nn.Module"], 203 lambda_l1: float = 1e-5, 204) -> torch.Tensor: 205 """L1 penalty on output layer weights of each subnetwork. 206 207 Encourages some subnetworks to output near-zero (feature selection). 208 Applied only to the output layer to avoid over-shrinking intermediate 209 representations. 210 """ 211 penalty = torch.tensor(0.0) 212 for net in feature_networks: 213 # Collect only the output layer: the last nn.Linear in the network. 214 # Penalising all weight layers would shrink intermediate representations 215 # too aggressively — the docstring is clear that only the output layer 216 # is the right target for feature-selection sparsity. 217 import torch.nn as nn 218 linear_layers = [m for m in net.modules() if isinstance(m, nn.Linear)] 219 if linear_layers: 220 output_layer = linear_layers[-1] 221 penalty = penalty + output_layer.weight.abs().sum() 222 return lambda_l1 * penalty
L1 penalty on output layer weights of each subnetwork.
Encourages some subnetworks to output near-zero (feature selection). Applied only to the output layer to avoid over-shrinking intermediate representations.
225def l2_ridge_penalty( 226 feature_networks: list["torch.nn.Module"], 227 lambda_l2: float = 1e-4, 228) -> torch.Tensor: 229 """L2 ridge penalty across all subnetwork weights. 230 231 Standard weight decay. Stabilises training especially when many 232 subnetworks sum together — without it, individual nets can grow 233 large while cancelling each other out. 234 """ 235 penalty = torch.tensor(0.0) 236 for net in feature_networks: 237 for param in net.parameters(): 238 penalty = penalty + (param ** 2).sum() 239 return lambda_l2 * penalty
L2 ridge penalty across all subnetwork weights.
Standard weight decay. Stabilises training especially when many subnetworks sum together — without it, individual nets can grow large while cancelling each other out.
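Both penalties are typically added to the same batch loss; the trainer does this over model.feature_nets (sketch, with loss and model as placeholders)::

    nets = list(model.feature_nets.values())
    loss = loss + l2_ridge_penalty(nets, lambda_l2=1e-4)
    loss = loss + l1_sparsity_penalty(nets, lambda_l1=1e-5)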
33@dataclass 34class ShapeFunction: 35 """Extracted shape function for one feature. 36 37 Contains the evaluated curve (x_values, f_values) and metadata for 38 reporting. Created by ANAM.shape_functions() after fitting. 39 40 Attributes 41 ---------- 42 feature_name: 43 Feature identifier. 44 feature_type: 45 'continuous' or 'categorical'. 46 x_values: 47 Grid of feature values (continuous) or category indices (categorical). 48 f_values: 49 Corresponding subnetwork outputs f_i(x_i). 50 x_label: 51 Human-readable x-axis label for plots. 52 monotonicity: 53 Monotonicity constraint applied during training. 54 category_labels: 55 Optional mapping from category index to label string. 56 """ 57 58 feature_name: str 59 feature_type: str 60 x_values: np.ndarray 61 f_values: np.ndarray 62 x_label: str = "" 63 monotonicity: str = "none" 64 category_labels: Optional[Dict[int, str]] = None 65 66 def to_polars(self) -> pl.DataFrame: 67 """Export shape function as a Polars DataFrame. 68 69 For continuous features: columns [x, f_x] 70 For categorical features: columns [category_index, category_label, f_x] 71 """ 72 if self.feature_type == "continuous": 73 return pl.DataFrame( 74 { 75 "x": self.x_values.tolist(), 76 "f_x": self.f_values.tolist(), 77 "feature": [self.feature_name] * len(self.x_values), 78 } 79 ) 80 else: 81 labels = [ 82 self.category_labels.get(int(i), str(int(i))) 83 if self.category_labels 84 else str(int(i)) 85 for i in self.x_values 86 ] 87 return pl.DataFrame( 88 { 89 "category_index": self.x_values.astype(int).tolist(), 90 "category_label": labels, 91 "f_x": self.f_values.tolist(), 92 "feature": [self.feature_name] * len(self.x_values), 93 } 94 ) 95 96 def to_relativities(self, base_level: Optional[float] = None) -> pl.DataFrame: 97 """Convert shape function to GLM-style multiplicative relativities. 98 99 For log-link models, the shape function is on the log scale. 100 exp(f_i(x_i)) gives the multiplicative factor. This is then 101 normalised by the value at the base level (default: median x). 102 103 Parameters 104 ---------- 105 base_level: 106 Feature value to use as base (relativity = 1.0). For 107 continuous features, defaults to median. For categorical, 108 defaults to category with f_x closest to zero. 
109 110 Returns 111 ------- 112 pl.DataFrame 113 Columns: [x, relativity, log_relativity] 114 """ 115 if self.feature_type == "continuous": 116 if base_level is None: 117 mid_idx = len(self.x_values) // 2 118 base_f = self.f_values[mid_idx] 119 else: 120 # Find nearest grid point 121 idx = np.argmin(np.abs(self.x_values - base_level)) 122 base_f = self.f_values[idx] 123 124 log_rel = self.f_values - base_f 125 rel = np.exp(log_rel) 126 127 return pl.DataFrame( 128 { 129 "x": self.x_values.tolist(), 130 "relativity": rel.tolist(), 131 "log_relativity": log_rel.tolist(), 132 "feature": [self.feature_name] * len(self.x_values), 133 } 134 ) 135 else: 136 if base_level is None: 137 base_f = self.f_values[np.argmin(np.abs(self.f_values))] 138 else: 139 base_f = self.f_values[int(base_level)] 140 141 log_rel = self.f_values - base_f 142 rel = np.exp(log_rel) 143 labels = [ 144 self.category_labels.get(int(i), str(int(i))) 145 if self.category_labels 146 else str(int(i)) 147 for i in self.x_values 148 ] 149 150 return pl.DataFrame( 151 { 152 "category_index": self.x_values.astype(int).tolist(), 153 "category_label": labels, 154 "relativity": rel.tolist(), 155 "log_relativity": log_rel.tolist(), 156 "feature": [self.feature_name] * len(self.x_values), 157 } 158 ) 159 160 def to_dict(self) -> Dict[str, Any]: 161 """Serialise to a plain dict (JSON-compatible).""" 162 return { 163 "feature_name": self.feature_name, 164 "feature_type": self.feature_type, 165 "monotonicity": self.monotonicity, 166 "x_values": self.x_values.tolist(), 167 "f_values": self.f_values.tolist(), 168 "x_label": self.x_label, 169 "category_labels": self.category_labels, 170 } 171 172 def to_json(self, indent: int = 2) -> str: 173 """Serialise to JSON string.""" 174 return json.dumps(self.to_dict(), indent=indent) 175 176 def plot( 177 self, 178 ax: Optional[Any] = None, 179 show_monotonicity: bool = True, 180 title: Optional[str] = None, 181 figsize: Tuple[int, int] = (8, 4), 182 ) -> Any: 183 """Plot the shape function. 184 185 Parameters 186 ---------- 187 ax: 188 Matplotlib axes object. Creates a new figure if None. 189 show_monotonicity: 190 Annotate the plot with the monotonicity constraint. 191 title: 192 Plot title. Defaults to feature name. 
193 194 Returns 195 ------- 196 matplotlib.axes.Axes 197 """ 198 import matplotlib.pyplot as plt 199 200 if ax is None: 201 _, ax = plt.subplots(figsize=figsize) 202 203 if self.feature_type == "continuous": 204 ax.plot(self.x_values, self.f_values, linewidth=2, color="#2c6fad") 205 ax.axhline(0, color="gray", linewidth=0.8, linestyle="--", alpha=0.5) 206 ax.fill_between( 207 self.x_values, self.f_values, 0, 208 alpha=0.15, color="#2c6fad" 209 ) 210 ax.set_xlabel(self.x_label or self.feature_name) 211 ax.set_ylabel("log contribution") 212 213 if show_monotonicity and self.monotonicity != "none": 214 mono_label = f"monotone {self.monotonicity}" 215 ax.annotate( 216 mono_label, 217 xy=(0.02, 0.95), 218 xycoords="axes fraction", 219 fontsize=9, 220 color="darkgreen", 221 va="top", 222 ) 223 224 else: 225 # Categorical: bar chart 226 labels = [ 227 self.category_labels.get(int(i), str(int(i))) 228 if self.category_labels 229 else str(int(i)) 230 for i in self.x_values 231 ] 232 colors = [ 233 "#d63031" if v < 0 else "#0984e3" for v in self.f_values 234 ] 235 ax.bar(labels, self.f_values, color=colors, edgecolor="white") 236 ax.axhline(0, color="gray", linewidth=0.8, linestyle="--") 237 ax.set_xlabel(self.x_label or self.feature_name) 238 ax.set_ylabel("log contribution") 239 ax.tick_params(axis="x", rotation=45) 240 241 ax.set_title(title or f"Shape function: {self.feature_name}") 242 ax.spines["top"].set_visible(False) 243 ax.spines["right"].set_visible(False) 244 245 return ax
Extracted shape function for one feature.
Contains the evaluated curve (x_values, f_values) and metadata for reporting. Created by ANAM.shape_functions() after fitting.
Attributes
feature_name:
    Feature identifier.
feature_type:
    'continuous' or 'categorical'.
x_values:
    Grid of feature values (continuous) or category indices (categorical).
f_values:
    Corresponding subnetwork outputs f_i(x_i).
x_label:
    Human-readable x-axis label for plots.
monotonicity:
    Monotonicity constraint applied during training.
category_labels:
    Optional mapping from category index to label string.
Export shape function as a Polars DataFrame.
For continuous features: columns [x, f_x]
For categorical features: columns [category_index, category_label, f_x]
Convert shape function to GLM-style multiplicative relativities.
For log-link models, the shape function is on the log scale. exp(f_i(x_i)) gives the multiplicative factor. This is then normalised by the value at the base level (default: median x).
Parameters
base_level:
    Feature value to use as base (relativity = 1.0). For continuous features,
    defaults to median. For categorical, defaults to the category with f_x
    closest to zero.
Returns
pl.DataFrame Columns: [x, relativity, log_relativity]
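A usage sketch (the feature name is illustrative; shapes would come from extract_shape_functions below)::

    sf = shapes["driver_age"]
    rel = sf.to_relativities(base_level=0.0)   # relativity 1.0 at x = 0.0
    print(rel.head())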
Serialise to a plain dict (JSON-compatible).
Serialise to JSON string.
Plot the shape function.
Parameters
ax:
    Matplotlib axes object. Creates a new figure if None.
show_monotonicity:
    Annotate the plot with the monotonicity constraint.
title:
    Plot title. Defaults to feature name.
Returns
matplotlib.axes.Axes
248def extract_shape_functions( 249 model: "ANAMModel", 250 X_train: np.ndarray, 251 n_points: int = 200, 252 category_labels: Optional[Dict[str, Dict[int, str]]] = None, 253) -> Dict[str, ShapeFunction]: 254 """Extract all shape functions from a trained model. 255 256 Evaluates each subnetwork over the observed range of the training data. 257 258 Parameters 259 ---------- 260 model: 261 Trained ANAMModel. 262 X_train: 263 Training feature matrix. Used to determine observed feature ranges 264 and category sets. 265 n_points: 266 Number of grid points for continuous features. 267 category_labels: 268 Optional dict mapping feature_name -> {category_idx -> label string}. 269 270 Returns 271 ------- 272 Dict[str, ShapeFunction] 273 Feature name -> ShapeFunction for each feature. 274 """ 275 shapes: Dict[str, ShapeFunction] = {} 276 277 model.eval() 278 with torch.no_grad(): 279 for i, cfg in enumerate(model.feature_configs): 280 col = X_train[:, i] 281 net = model.feature_nets[cfg.name] 282 cat_labels = ( 283 category_labels.get(cfg.name) if category_labels else None 284 ) 285 286 if cfg.feature_type == "continuous": 287 x_min, x_max = float(col.min()), float(col.max()) 288 x_grid = torch.linspace(x_min, x_max, n_points) 289 f_vals = net(x_grid.unsqueeze(-1)).squeeze(-1).cpu().numpy() 290 x_vals = x_grid.cpu().numpy() 291 292 else: 293 unique_cats = np.unique(col.astype(int)) 294 x_cat = torch.tensor(unique_cats, dtype=torch.long) 295 f_vals = net(x_cat).squeeze(-1).cpu().numpy() 296 x_vals = unique_cats.astype(float) 297 298 shapes[cfg.name] = ShapeFunction( 299 feature_name=cfg.name, 300 feature_type=cfg.feature_type, 301 x_values=x_vals, 302 f_values=f_vals, 303 x_label=cfg.name, 304 monotonicity=cfg.monotonicity if hasattr(cfg, "monotonicity") else "none", 305 category_labels=cat_labels, 306 ) 307 308 return shapes
Extract all shape functions from a trained model.
Evaluates each subnetwork over the observed range of the training data.
Parameters
model:
    Trained ANAMModel.
X_train:
    Training feature matrix. Used to determine observed feature ranges and
    category sets.
n_points:
    Number of grid points for continuous features.
category_labels:
    Optional dict mapping feature_name -> {category_idx -> label string}.
Returns
Dict[str, ShapeFunction] Feature name -> ShapeFunction for each feature.
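An extraction sketch (X_train_scaled, the feature name and the category labels are illustrative placeholders)::

    shapes = extract_shape_functions(
        model,
        X_train_scaled,
        n_points=200,
        category_labels={"region": {0: "urban", 1: "rural"}},
    )
    shapes["driver_age"].plot()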
311def plot_all_shapes( 312 shapes: Dict[str, ShapeFunction], 313 n_cols: int = 3, 314 figsize_per_plot: Tuple[int, int] = (5, 3), 315 suptitle: str = "ANAM Shape Functions", 316) -> Any: 317 """Plot all shape functions in a grid layout. 318 319 Parameters 320 ---------- 321 shapes: 322 Dict returned by extract_shape_functions(). 323 n_cols: 324 Number of columns in the subplot grid. 325 figsize_per_plot: 326 Width and height per subplot panel. 327 suptitle: 328 Overall figure title. 329 330 Returns 331 ------- 332 matplotlib.figure.Figure 333 """ 334 import matplotlib.pyplot as plt 335 336 n = len(shapes) 337 n_rows = (n + n_cols - 1) // n_cols 338 fig_w = figsize_per_plot[0] * n_cols 339 fig_h = figsize_per_plot[1] * n_rows 340 341 fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_w, fig_h)) 342 if n == 1: 343 axes = np.array([[axes]]) 344 elif n_rows == 1: 345 axes = axes.reshape(1, -1) 346 elif n_cols == 1: 347 axes = axes.reshape(-1, 1) 348 349 axes_flat = axes.flatten() 350 351 for ax, (name, sf) in zip(axes_flat, shapes.items()): 352 sf.plot(ax=ax) 353 354 # Hide unused subplots 355 for ax in axes_flat[n:]: 356 ax.set_visible(False) 357 358 fig.suptitle(suptitle, fontsize=13, y=1.01) 359 fig.tight_layout() 360 return fig
Plot all shape functions in a grid layout.
Parameters
shapes:
    Dict returned by extract_shape_functions().
n_cols:
    Number of columns in the subplot grid.
figsize_per_plot:
    Width and height per subplot panel.
suptitle:
    Overall figure title.
Returns
matplotlib.figure.Figure
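Continuing the extraction sketch above::

    fig = plot_all_shapes(shapes, n_cols=3, suptitle="ANAM Shape Functions")
    fig.savefig("shape_functions.png", dpi=150, bbox_inches="tight")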
33class StandardScaler: 34 """Simple StandardScaler that tracks feature names and ranges. 35 36 Stores mean/std for inverse transformation (needed to recover original 37 feature values when plotting shape functions). 38 """ 39 40 def __init__(self) -> None: 41 self.means_: Optional[np.ndarray] = None 42 self.stds_: Optional[np.ndarray] = None 43 self.feature_names_: Optional[List[str]] = None 44 45 def fit( 46 self, 47 X: np.ndarray, 48 feature_names: Optional[List[str]] = None, 49 ) -> "StandardScaler": 50 self.means_ = X.mean(axis=0) 51 self.stds_ = X.std(axis=0).clip(min=1e-8) 52 self.feature_names_ = feature_names or [f"f{i}" for i in range(X.shape[1])] 53 return self 54 55 def transform(self, X: np.ndarray) -> np.ndarray: 56 if self.means_ is None: 57 raise RuntimeError("Call fit() before transform().") 58 return (X - self.means_) / self.stds_ 59 60 def fit_transform(self, X: np.ndarray, feature_names: Optional[List[str]] = None) -> np.ndarray: 61 return self.fit(X, feature_names).transform(X) 62 63 def inverse_transform(self, X_scaled: np.ndarray) -> np.ndarray: 64 if self.means_ is None: 65 raise RuntimeError("Call fit() before inverse_transform().") 66 return X_scaled * self.stds_ + self.means_ 67 68 def inverse_transform_col(self, x_scaled: np.ndarray, col_idx: int) -> np.ndarray: 69 """Inverse transform a single column.""" 70 assert self.means_ is not None 71 return x_scaled * self.stds_[col_idx] + self.means_[col_idx]
Simple StandardScaler that tracks feature names and ranges.
Stores mean/std for inverse transformation (needed to recover original feature values when plotting shape functions).
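A usage sketch (X is a placeholder (n, 2) array; the feature names are illustrative)::

    import numpy as np

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X, feature_names=["driver_age", "veh_power"])
    X_back = scaler.inverse_transform(X_scaled)            # back to original scale
    age_grid = scaler.inverse_transform_col(np.linspace(-2, 2, 50), col_idx=0)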
Inverse transform a single column.
79def select_interactions_correlation( 80 X: np.ndarray, 81 feature_names: List[str], 82 threshold: float = 0.3, 83 top_k: Optional[int] = 10, 84 exclude_categorical: Optional[List[int]] = None, 85) -> List[Tuple[str, str, float]]: 86 """Screen feature pairs for interaction candidates using Pearson correlation. 87 88 Pairs with |r| above threshold are candidates for interaction subnetworks. 89 Lower |r| means the two features are more independent in the feature space 90 — which paradoxically can mean their interaction is more surprising and 91 worth modelling. This heuristic selects the *highest* correlation pairs 92 because they often represent rate structures where the joint effect differs 93 from the sum of main effects. 94 95 Parameters 96 ---------- 97 X: 98 Feature matrix (n, p). Should be the continuous features only. 99 feature_names: 100 Names matching columns of X. 101 threshold: 102 Minimum |correlation| to include pair. Default 0.3. 103 top_k: 104 Maximum number of pairs to return. None = all above threshold. 105 exclude_categorical: 106 Column indices of categorical features to exclude from screening. 107 108 Returns 109 ------- 110 List[Tuple[str, str, float]] 111 List of (feature_i, feature_j, correlation) sorted by |r| descending. 112 """ 113 exclude = set(exclude_categorical or []) 114 p = X.shape[1] 115 116 candidates: List[Tuple[str, str, float]] = [] 117 118 for i in range(p): 119 if i in exclude: 120 continue 121 for j in range(i + 1, p): 122 if j in exclude: 123 continue 124 corr = float(np.corrcoef(X[:, i], X[:, j])[0, 1]) 125 if abs(corr) >= threshold: 126 candidates.append((feature_names[i], feature_names[j], corr)) 127 128 candidates.sort(key=lambda x: abs(x[2]), reverse=True) 129 130 if top_k is not None: 131 candidates = candidates[:top_k] 132 133 return candidates
Screen feature pairs for interaction candidates using Pearson correlation.
Pairs with |r| above the threshold are candidates for interaction subnetworks. The heuristic keeps the most strongly correlated pairs: in rating data, high feature-space correlation often marks rate structures where the joint effect differs from the sum of the main effects. Correlation is only a screen, though; a weakly correlated pair can still interact strongly, which is why the residual-based selector below is the more direct test.
Parameters
X: Feature matrix (n, p). Should be the continuous features only. feature_names: Names matching columns of X. threshold: Minimum |correlation| to include pair. Default 0.3. top_k: Maximum number of pairs to return. None = all above threshold. exclude_categorical: Column indices of categorical features to exclude from screening.
Returns
List[Tuple[str, str, float]] List of (feature_i, feature_j, correlation) sorted by |r| descending.
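A hedged sketch of how the correlation screen might be called before building interaction subnetworks; the data and feature names are illustrative, not part of the package:

import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(5000, 4))
X[:, 1] = 0.6 * X[:, 0] + 0.4 * rng.normal(size=5000)  # induce a correlated pair

pairs = select_interactions_correlation(
    X,
    feature_names=["driver_age", "licence_years", "vehicle_power", "region_density"],
    threshold=0.3,
    top_k=5,
)
# pairs is a list like [("driver_age", "licence_years", 0.8), ...];
# the name tuples can be passed as interaction_pairs to the ANAM estimator.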
def select_interactions_residual(
    X: np.ndarray,
    y_residuals: np.ndarray,
    feature_names: List[str],
    top_k: int = 5,
    exclude_categorical: Optional[List[int]] = None,
) -> List[Tuple[str, str, float]]:
    """Select interaction pairs by correlating pairwise products with residuals.

    After fitting the additive model, the residuals contain the variance the
    main effects could not explain. This method checks whether x_i * x_j
    correlates with those residuals; if it does, the interaction term
    x_i x_j may be worth adding.

    This is more principled than a pure feature-space correlation screen for
    identifying interactions that actually improve model fit.

    Parameters
    ----------
    X:
        Feature matrix.
    y_residuals:
        Model residuals (y - predicted).
    feature_names:
        Feature names matching the columns of X.
    top_k:
        Number of top pairs to return.
    exclude_categorical:
        Column indices of categorical features to exclude from screening.

    Returns
    -------
    List[Tuple[str, str, float]]
        List of (feature_i, feature_j, |correlation with residuals|),
        sorted descending.
    """
    exclude = set(exclude_categorical or [])
    p = X.shape[1]
    candidates: List[Tuple[str, str, float]] = []

    for i in range(p):
        if i in exclude:
            continue
        for j in range(i + 1, p):
            if j in exclude:
                continue
            product = X[:, i] * X[:, j]
            corr = float(np.corrcoef(product, y_residuals)[0, 1])
            candidates.append((feature_names[i], feature_names[j], abs(corr)))

    candidates.sort(key=lambda x: x[2], reverse=True)
    return candidates[:top_k]
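A sketch of the residual-based screen after a first additive fit; model, X, y and feature_names are placeholders for an already-fitted estimator and its training data, not names defined in this module:

# Residuals of a first, main-effects-only fit (illustrative names).
residuals = y - model.predict(X)

pairs = select_interactions_residual(
    X,
    y_residuals=residuals,
    feature_names=feature_names,
    top_k=5,
)

# Feed the selected pairs back into a refit with interaction subnetworks.
interaction_pairs = [(a, b) for a, b, _ in pairs]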
def shapes_to_relativity_table(
    shapes: Dict[str, "ShapeFunction"],
    feature_names: Optional[List[str]] = None,
) -> pl.DataFrame:
    """Aggregate all shape functions into a single relativity table.

    Outputs a long-format DataFrame with one row per (feature, level)
    combination, suitable for actuarial review in Excel.

    Parameters
    ----------
    shapes:
        Dict returned by extract_shape_functions().
    feature_names:
        Subset of feature names to include. None = all features.

    Returns
    -------
    pl.DataFrame
        Columns: [feature, level, f_x, relativity, log_relativity]
    """
    names = feature_names or list(shapes.keys())
    dfs: List[pl.DataFrame] = []

    for name in names:
        if name not in shapes:
            continue
        sf = shapes[name]
        rel_df = sf.to_relativities()

        if sf.feature_type == "continuous":
            level_col = pl.Series("level", [str(round(float(x), 4)) for x in sf.x_values])
        else:
            level_col = pl.Series(
                "level",
                [
                    sf.category_labels.get(int(i), str(int(i))) if sf.category_labels else str(int(i))
                    for i in sf.x_values
                ],
            )

        feature_col = pl.Series("feature", [name] * len(sf.x_values))
        f_x_col = pl.Series("f_x", sf.f_values.tolist())

        rel_col = rel_df["relativity"]
        log_rel_col = rel_df["log_relativity"]

        dfs.append(
            pl.DataFrame(
                {
                    "feature": feature_col,
                    "level": level_col,
                    "f_x": f_x_col,
                    "relativity": rel_col,
                    "log_relativity": log_rel_col,
                }
            )
        )

    if not dfs:
        return pl.DataFrame(
            schema={
                "feature": pl.Utf8,
                "level": pl.Utf8,
                "f_x": pl.Float64,
                "relativity": pl.Float64,
                "log_relativity": pl.Float64,
            }
        )

    return pl.concat(dfs)
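A brief usage sketch, assuming shapes is the dict returned by extract_shape_functions(); the output file name and the feature name used in the filter are illustrative:

import polars as pl

rel_table = shapes_to_relativity_table(shapes)

# Hand off to actuarial review as a flat file.
rel_table.write_csv("anam_relativities.csv")

# Inspect a single rating factor in long format.
print(rel_table.filter(pl.col("feature") == "driver_age"))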
def compare_shapes_to_glm(
    anam_shapes: Dict[str, "ShapeFunction"],
    glm_coefficients: Dict[str, Dict[str, float]],
) -> pl.DataFrame:
    """Compare ANAM shape functions to GLM log-relativities.

    For each feature level in the GLM, find the nearest ANAM shape function
    value and compute the deviation. Useful for model validation against
    existing production GLMs.

    Parameters
    ----------
    anam_shapes:
        ANAM shape functions (from extract_shape_functions).
    glm_coefficients:
        Dict mapping feature_name -> {level_str -> log_relativity}.
        The GLM log-relativities must be on the same scale as the ANAM
        f_i outputs.

    Returns
    -------
    pl.DataFrame
        Columns: [feature, level, anam_f, glm_log_rel, deviation]
    """
    rows: List[Dict] = []

    for feature_name, glm_levels in glm_coefficients.items():
        if feature_name not in anam_shapes:
            continue

        sf = anam_shapes[feature_name]

        for level_str, glm_val in glm_levels.items():
            # Find the ANAM value at this level.
            if sf.feature_type == "continuous":
                try:
                    x_level = float(level_str)
                    idx = int(np.argmin(np.abs(sf.x_values - x_level)))
                    anam_val = float(sf.f_values[idx])
                except (ValueError, IndexError):
                    continue
            else:
                try:
                    cat_idx = int(level_str)
                    if cat_idx < len(sf.f_values):
                        anam_val = float(sf.f_values[cat_idx])
                    else:
                        continue
                except (ValueError, IndexError):
                    continue

            rows.append(
                {
                    "feature": feature_name,
                    "level": level_str,
                    "anam_f": anam_val,
                    "glm_log_rel": glm_val,
                    "deviation": anam_val - glm_val,
                }
            )

    if not rows:
        return pl.DataFrame(
            schema={
                "feature": pl.Utf8,
                "level": pl.Utf8,
                "anam_f": pl.Float64,
                "glm_log_rel": pl.Float64,
                "deviation": pl.Float64,
            }
        )

    return pl.DataFrame(rows)
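A hedged example of the expected glm_coefficients layout and of how the comparison might be reviewed; the feature names, levels, and values are invented for illustration, and shapes is assumed to come from extract_shape_functions():

import polars as pl

# feature -> {level string -> GLM log-relativity}
glm_coefficients = {
    "vehicle_age": {"0.0": 0.00, "5.0": -0.08, "10.0": -0.15},
    "region": {"0": 0.00, "1": 0.12, "2": -0.05},
}

comparison = compare_shapes_to_glm(anam_shapes=shapes, glm_coefficients=glm_coefficients)

# Levels with the largest absolute deviation are the ones to review first.
worst = comparison.with_columns(
    pl.col("deviation").abs().alias("abs_deviation")
).sort("abs_deviation", descending=True)
print(worst.head())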
def compute_deviance_stat(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    exposure: Optional[np.ndarray] = None,
    loss: str = "poisson",
    tweedie_p: float = 1.5,
    eps: float = 1e-8,
) -> float:
    """Compute weighted deviance statistic for model comparison.

    Returns the mean deviance (lower is better). Useful for comparing
    ANAM to GLM/EBM baselines with a single number.
    """
    import torch

    y_t = torch.tensor(y_true, dtype=torch.float32)
    yp_t = torch.tensor(y_pred, dtype=torch.float32)
    w_t = torch.tensor(exposure, dtype=torch.float32) if exposure is not None else None

    from .losses import gamma_deviance, poisson_deviance, tweedie_deviance

    if loss == "poisson":
        return float(poisson_deviance(yp_t, y_t, w_t, eps).item())
    elif loss == "tweedie":
        return float(tweedie_deviance(yp_t, y_t, tweedie_p, w_t, eps).item())
    elif loss == "gamma":
        return float(gamma_deviance(yp_t, y_t, w_t, eps).item())
    else:
        raise ValueError(f"Unknown loss: {loss!r}")
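A short comparison sketch on a holdout set; y_test, anam_pred, glm_pred, and exposure_test are illustrative arrays, not names defined here:

# Lower mean deviance is better; exposure weights each observation.
anam_dev = compute_deviance_stat(y_test, anam_pred, exposure=exposure_test, loss="poisson")
glm_dev = compute_deviance_stat(y_test, glm_pred, exposure=exposure_test, loss="poisson")
print(f"Poisson deviance  ANAM: {anam_dev:.5f}  GLM: {glm_dev:.5f}")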