insurance_gam.pin
insurance_gam.pin — Pairwise Interaction Networks subpackage.
Re-exports the full public API of the original insurance-pin package.
1""" 2insurance_gam.pin — Pairwise Interaction Networks subpackage. 3 4Re-exports the full public API of the original insurance-pin package. 5""" 6 7from .model import PINModel, PINEnsemble 8from .diagnostics import PINDiagnostics 9from .networks import centered_hard_sigmoid 10 11__all__ = [ 12 "PINModel", 13 "PINEnsemble", 14 "PINDiagnostics", 15 "centered_hard_sigmoid", 16]
class PINModel(nn.Module):
    """
    Single Tree-like Pairwise Interaction Network.

    Prediction:
        f_PIN(x) = exp( sum_{j<=k} w_{jk} * h_{jk}(x) + b )

    Args:
        features: Dict mapping feature name to spec. Spec is 'continuous' or
            an int (number of categories). Order matters — features are indexed
            by position.
        embedding_dim: Feature embedding dimension d (default 10).
        hidden_dim: Hidden width for continuous embedding FNNs d' (default 20).
        token_dim: Interaction token dimension d0 (default 10).
        shared_dims: (d1, d2) widths for shared interaction network (default (30, 20)).
        loss: Loss name — 'poisson', 'gamma', or 'tweedie'.
        tweedie_p: Tweedie power (only used when loss='tweedie').
        lr: Adam learning rate (default 0.001).
        batch_size: Mini-batch size (default 128).
        max_epochs: Maximum training epochs (default 500).
        patience: Early stopping patience in epochs (default 20).
        lr_patience: ReduceLROnPlateau patience (default 5).
        lr_factor: ReduceLROnPlateau reduction factor (default 0.9).
        val_fraction: Fraction of training data for validation if X_val not given
            (default 0.1).
        device: Torch device string, or None to auto-detect.
        random_seed: Seed for reproducibility.

    Examples:
        >>> model = PINModel(
        ...     features={"age": "continuous", "area": 5},
        ...     loss="poisson",
        ... )
    """

    def __init__(
        self,
        features: FeatureSpec,
        embedding_dim: int = 10,
        hidden_dim: int = 20,
        token_dim: int = 10,
        shared_dims: Tuple[int, int] = (30, 20),
        loss: str = "poisson",
        tweedie_p: float = 1.5,
        lr: float = 0.001,
        batch_size: int = 128,
        max_epochs: int = 500,
        patience: int = 20,
        lr_patience: int = 5,
        lr_factor: float = 0.9,
        val_fraction: float = 0.1,
        device: Optional[str] = None,
        random_seed: int = 42,
    ) -> None:
        super().__init__()

        self.features = dict(features)
        # Positional order of features fixes the (j, k) pair indexing below.
        self.feature_names: List[str] = list(features.keys())
        self.q = len(self.feature_names)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.token_dim = token_dim
        self.shared_dims = tuple(shared_dims)
        self.loss_name = loss
        self.tweedie_p = tweedie_p
        self.lr = lr
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.patience = patience
        self.lr_patience = lr_patience
        self.lr_factor = lr_factor
        self.val_fraction = val_fraction
        self.random_seed = random_seed

        # Resolve device
        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = torch.device(device)

        # --- Sub-modules ---
        self.feature_embeddings = FeatureEmbeddings(
            features=features,
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
        )
        self.interaction_tokens = InteractionTokens(
            n_features=self.q,
            token_dim=token_dim,
        )
        self.shared_net = SharedInteractionNet(
            embedding_dim=embedding_dim,
            token_dim=token_dim,
            layer1_dim=shared_dims[0],
            layer2_dim=shared_dims[1],
        )

        # All unordered pairs including the diagonal (j == k are main effects).
        n_pairs = self.q * (self.q + 1) // 2
        # Output weights w_{jk} and bias b — linear combination of pair terms
        self.output_weights = nn.Parameter(torch.randn(n_pairs) * 0.01)
        self.output_bias = nn.Parameter(torch.zeros(1))

        # Loss function
        loss_kwargs = {"p": tweedie_p} if loss == "tweedie" else {}
        self._loss_fn = get_loss(loss, **loss_kwargs)

        # Weight initialisation: smaller scales for numerical stability
        self._init_weights()

        # Training state
        self._is_fitted = False
        self.train_history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}

        # Centering offsets (set post-hoc after fitting)
        # h_{jk}^centered(x) = h_{jk}(x) - mean_train[h_{jk}]
        # We store mean_train[w_{jk} * h_{jk}] per pair for efficiency
        self._pair_means: Optional[torch.Tensor] = None

    # ------------------------------------------------------------------
    # Weight initialisation
    # ------------------------------------------------------------------

    def _init_weights(self) -> None:
        """
        Initialise all linear layers to smaller scales for training stability.

        Default PyTorch Kaiming init can produce large initial activations
        when many layers are composed. With smaller scales, early training
        is more stable even on very small datasets.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def _compute_linear_predictor(
        self,
        x_dict: Dict[str, torch.Tensor],
        apply_centering: bool = True,
    ) -> torch.Tensor:
        """
        Compute sum_{j<=k} w_{jk} * h_{jk}(x) + b.

        Args:
            x_dict: Feature tensors on self._device.
            apply_centering: Subtract pair means (set during fit) for identifiability.

        Returns:
            Shape (batch_size,).
        """
        # Embed all features once
        embeddings: Dict[str, torch.Tensor] = self.feature_embeddings.embed_all(x_dict)

        pairs = self.interaction_tokens.pair_indices()
        terms = []

        for pair_idx, (j, k) in enumerate(pairs):
            fname_j = self.feature_names[j]
            fname_k = self.feature_names[k]

            phi_j = embeddings[fname_j]
            phi_k = embeddings[fname_k]
            token = self.interaction_tokens.get_token(j, k)

            raw = self.shared_net(phi_j, phi_k, token)      # (batch, 1)
            h = centered_hard_sigmoid(raw).squeeze(-1)      # (batch,)

            w = self.output_weights[pair_idx]
            term = w * h                                    # (batch,)
            terms.append(term)

        # Stack to (batch, n_pairs) then sum over pairs
        all_terms = torch.stack(terms, dim=1)               # (batch, n_pairs)

        if apply_centering and self._pair_means is not None:
            all_terms = all_terms - self._pair_means.unsqueeze(0)

        linear_pred = all_terms.sum(dim=1) + self.output_bias.squeeze()
        return linear_pred

    def forward(
        self,
        x_dict: Dict[str, torch.Tensor],
        exposure: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute predictions f_PIN(x) = exp(linear_predictor) * exposure.

        When exposure is provided, the raw model output is frequency (claims per
        year) and multiplying by exposure gives expected claim count. For
        frequency models, typically you'd pass exposure=None and let the caller
        multiply; this method supports both modes.

        Args:
            x_dict: Feature tensors. Dict of feature_name -> (batch,) tensor.
            exposure: Optional per-sample exposure, shape (batch,).

        Returns:
            Predicted frequency, shape (batch,).
        """
        eta = self._compute_linear_predictor(x_dict)
        # Clamp to avoid exp overflow (GLM link stabilisation)
        eta = torch.clamp(eta, min=-20.0, max=20.0)
        mu = torch.exp(eta)
        if exposure is not None:
            mu = mu * exposure
        return mu

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------

    def _prepare_features(
        self,
        X: Union[Dict, "pl.DataFrame", "pd.DataFrame"],  # noqa: F821
    ) -> Dict[str, torch.Tensor]:
        """
        Convert input data to a dict of tensors.

        Accepts:
            - Dict[str, np.ndarray]
            - Dict[str, list]
            - polars.DataFrame
            - pandas.DataFrame
        """
        # Try polars first
        try:
            import polars as pl
            if isinstance(X, pl.DataFrame):
                return self._polars_to_dict(X)
        except ImportError:
            pass

        # Try pandas
        try:
            import pandas as pd
            if isinstance(X, pd.DataFrame):
                return {col: X[col].to_numpy() for col in self.feature_names}
        except ImportError:
            pass

        # Assume it's already a dict
        if isinstance(X, dict):
            return X

        raise TypeError(
            f"X must be a dict, polars.DataFrame, or pandas.DataFrame. Got {type(X)}."
        )

    def _polars_to_dict(self, df) -> Dict[str, np.ndarray]:
        """Extract the model's feature columns from a polars DataFrame as numpy arrays."""
        result = {}
        for name in self.feature_names:
            result[name] = df[name].to_numpy()
        return result

    def _to_device_dict(
        self, x_dict: Dict[str, Union[np.ndarray, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """Convert dict of arrays to tensors on device."""
        result = {}
        for name, arr in x_dict.items():
            spec = self.features[name]
            if spec == "continuous":
                # Continuous features become float tensors; categoricals long.
                result[name] = _to_tensor(arr).to(self._device)
            else:
                result[name] = _to_long_tensor(arr).to(self._device)
        return result

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def fit(
        self,
        X_train,
        y_train: np.ndarray,
        exposure: Optional[np.ndarray] = None,
        X_val=None,
        y_val: Optional[np.ndarray] = None,
        exposure_val: Optional[np.ndarray] = None,
        verbose: bool = True,
    ) -> "PINModel":
        """
        Fit the model.

        Args:
            X_train: Training features. Dict, polars.DataFrame, or pandas.DataFrame.
            y_train: Observed frequency (claims / exposure), shape (n,).
            exposure: Per-sample exposure (years at risk), shape (n,).
            X_val: Validation features (optional). If None, 10% of training data
                is reserved.
            y_val: Validation targets.
            exposure_val: Validation exposure.
            verbose: Print training progress.

        Returns:
            self (for chaining).
        """
        torch.manual_seed(self.random_seed)
        np.random.seed(self.random_seed)

        self.to(self._device)

        # Prepare training data
        x_dict = self._prepare_features(X_train)
        x_dict = self._to_device_dict(x_dict)
        y_t = _to_tensor(y_train).to(self._device)
        exp_t = (
            _to_tensor(exposure).to(self._device)
            if exposure is not None
            else torch.ones_like(y_t)
        )

        n = y_t.shape[0]

        # Build or reserve validation set (must happen before bias init so we
        # use training-only data for the mean frequency estimate).
        if X_val is not None:
            x_val_dict = self._prepare_features(X_val)
            x_val_dict = self._to_device_dict(x_val_dict)
            y_val_t = _to_tensor(y_val).to(self._device)
            exp_val_t = (
                _to_tensor(exposure_val).to(self._device)
                if exposure_val is not None
                else torch.ones_like(y_val_t)
            )
        else:
            # Reserve val_fraction from training
            val_size = max(1, int(n * self.val_fraction))
            perm = torch.randperm(n, device=self._device)
            val_idx = perm[:val_size]
            train_idx = perm[val_size:]

            x_val_dict = {k: v[val_idx] for k, v in x_dict.items()}
            y_val_t = y_t[val_idx]
            exp_val_t = exp_t[val_idx]

            x_dict = {k: v[train_idx] for k, v in x_dict.items()}
            y_t = y_t[train_idx]
            exp_t = exp_t[train_idx]

        n_train = y_t.shape[0]

        # Initialise bias to log(mean_frequency) of the TRAINING split only.
        # Computing this from the full dataset (including validation) would
        # be a mild form of data leakage — small in practice but wrong in
        # principle. We compute it here, after the split.
        with torch.no_grad():
            mean_freq = y_t.mean().clamp(min=1e-8)
            self.output_bias.fill_(torch.log(mean_freq).item())

        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=self.lr_factor,
            patience=self.lr_patience,
        )

        best_val_loss = float("inf")
        best_state = None
        epochs_no_improve = 0

        self.train_history = {"train_loss": [], "val_loss": []}

        for epoch in range(self.max_epochs):
            self.train()
            # Shuffle training data
            perm = torch.randperm(n_train, device=self._device)

            epoch_loss = 0.0
            n_batches = 0

            for start in range(0, n_train, self.batch_size):
                end = min(start + self.batch_size, n_train)
                batch_idx = perm[start:end]

                x_batch = {k: v[batch_idx] for k, v in x_dict.items()}
                y_batch = y_t[batch_idx]
                exp_batch = exp_t[batch_idx]

                optimizer.zero_grad()
                mu = self.forward(x_batch)
                loss = self._loss_fn(mu, y_batch, exp_batch)
                if torch.isnan(loss):
                    continue  # skip NaN batches
                loss.backward()
                # Gradient clipping prevents exploding gradients in early training
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                # Zero NaN gradients before step
                for p in self.parameters():
                    if p.grad is not None and torch.isnan(p.grad).any():
                        p.grad.zero_()
                optimizer.step()
                # Last-resort recovery: replace any NaN parameter values with
                # zero (note: this sets them to 0.0, it does NOT restore the
                # pre-step values).
                for p in self.parameters():
                    if torch.isnan(p).any():
                        torch.nan_to_num_(p, nan=0.0)

                epoch_loss += loss.item()
                n_batches += 1

            # NOTE(review): if every batch in an epoch is NaN-skipped,
            # n_batches is 0 and this raises ZeroDivisionError — confirm
            # whether that is the intended failure mode.
            train_loss = epoch_loss / n_batches

            # Validation
            self.eval()
            with torch.no_grad():
                mu_val = self.forward(x_val_dict)
                val_loss = self._loss_fn(mu_val, y_val_t, exp_val_t).item()

            scheduler.step(val_loss)

            self.train_history["train_loss"].append(train_loss)
            self.train_history["val_loss"].append(val_loss)

            if verbose and (epoch % 50 == 0 or epoch < 10):
                lr_now = optimizer.param_groups[0]["lr"]
                print(
                    f"Epoch {epoch:4d} | train={train_loss:.6f} | val={val_loss:.6f} | lr={lr_now:.2e}"
                )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_state = copy.deepcopy(self.state_dict())
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= self.patience:
                    if verbose:
                        print(f"Early stopping at epoch {epoch}. Best val={best_val_loss:.6f}")
                    break

        # Restore best weights
        if best_state is not None:
            self.load_state_dict(best_state)

        # Post-hoc centering: compute pair means on training data
        self._compute_pair_centering(x_dict)

        self._is_fitted = True
        return self

    def _compute_pair_centering(
        self, x_dict: Dict[str, torch.Tensor], batch_size: int = 1024
    ) -> None:
        """
        Compute mean of w_{jk} * h_{jk}(x) over training data for each pair.

        Centering ensures that the bias b absorbs the intercept and the
        interaction terms have zero mean — making w_{jk} interpretable as
        the amplitude of the interaction effect.

        Stores result in self._pair_means, shape (n_pairs,).
        """
        self.eval()
        n = x_dict[self.feature_names[0]].shape[0]
        n_pairs = self.q * (self.q + 1) // 2
        pairs = self.interaction_tokens.pair_indices()

        pair_sums = torch.zeros(n_pairs, device=self._device)
        n_seen = 0

        with torch.no_grad():
            # Batched accumulation keeps memory bounded on large training sets.
            for start in range(0, n, batch_size):
                end = min(start + batch_size, n)
                x_batch = {k: v[start:end] for k, v in x_dict.items()}
                embeddings = self.feature_embeddings.embed_all(x_batch)
                batch_n = end - start

                for pair_idx, (j, k) in enumerate(pairs):
                    fname_j = self.feature_names[j]
                    fname_k = self.feature_names[k]
                    phi_j = embeddings[fname_j]
                    phi_k = embeddings[fname_k]
                    token = self.interaction_tokens.get_token(j, k)
                    raw = self.shared_net(phi_j, phi_k, token)
                    h = centered_hard_sigmoid(raw).squeeze(-1)
                    w = self.output_weights[pair_idx]
                    pair_sums[pair_idx] += (w * h).sum()

                n_seen += batch_n

        self._pair_means = pair_sums / n_seen

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def predict(
        self,
        X,
        exposure: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Predict frequency (or expected claims if exposure given).

        Args:
            X: Features. Dict, polars.DataFrame, or pandas.DataFrame.
            exposure: Per-sample exposure.

        Returns:
            Predictions as numpy array, shape (n,).
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before predict().")

        self.eval()
        x_dict = self._prepare_features(X)
        x_dict = self._to_device_dict(x_dict)

        exp_t = (
            _to_tensor(exposure).to(self._device)
            if exposure is not None
            else None
        )

        with torch.no_grad():
            mu = self.forward(x_dict, exposure=exp_t)

        return mu.cpu().numpy()

    # ------------------------------------------------------------------
    # Interpretability surfaces
    # ------------------------------------------------------------------

    def pair_contributions(
        self,
        X,
    ) -> Dict[Tuple[str, str], np.ndarray]:
        """
        Compute w_{jk} * h_{jk}(x) for each pair and each sample.

        These are the additive components on the linear predictor scale.
        Useful for understanding which pairs drive predictions.

        Args:
            X: Features.

        Returns:
            Dict mapping (fname_j, fname_k) -> array of shape (n,).
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before pair_contributions().")

        self.eval()
        x_dict = self._prepare_features(X)
        x_dict = self._to_device_dict(x_dict)
        pairs = self.interaction_tokens.pair_indices()
        result = {}

        with torch.no_grad():
            embeddings = self.feature_embeddings.embed_all(x_dict)

            for pair_idx, (j, k) in enumerate(pairs):
                fname_j = self.feature_names[j]
                fname_k = self.feature_names[k]
                phi_j = embeddings[fname_j]
                phi_k = embeddings[fname_k]
                token = self.interaction_tokens.get_token(j, k)
                raw = self.shared_net(phi_j, phi_k, token)
                h = centered_hard_sigmoid(raw).squeeze(-1)
                w = self.output_weights[pair_idx]
                contrib = (w * h)
                # Apply the same centering used by the linear predictor so the
                # contributions sum (plus bias) to the model's eta.
                if self._pair_means is not None:
                    contrib = contrib - self._pair_means[pair_idx]
                result[(fname_j, fname_k)] = contrib.cpu().numpy()

        return result

    def main_effects(
        self,
        X_background,
        n_grid: int = 100,
    ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """
        Compute main effect curves for each feature.

        For continuous features: evaluates h_{jj} over an evenly-spaced grid
        while fixing all other features to background means.

        Returns:
            Dict mapping feature_name -> (grid_values, effect_values).
            For categoricals: grid_values are integer category codes.
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before main_effects().")

        self.eval()
        bg_dict = self._prepare_features(X_background)
        bg_dict = self._to_device_dict(bg_dict)

        result = {}

        with torch.no_grad():
            for i, fname in enumerate(self.feature_names):
                spec = self.features[fname]

                if spec == "continuous":
                    vals = bg_dict[fname].float()
                    lo, hi = vals.min().item(), vals.max().item()
                    grid = torch.linspace(lo, hi, n_grid, device=self._device)
                elif isinstance(spec, int):
                    grid = torch.arange(spec, device=self._device)
                else:
                    continue

                # Build batch: vary feature i, fix others to background mean
                n_grid_pts = grid.shape[0]
                x_eval = {}
                for fname2 in self.feature_names:
                    spec2 = self.features[fname2]
                    if fname2 == fname:
                        x_eval[fname2] = grid.long() if isinstance(spec, int) else grid
                    else:
                        if spec2 == "continuous":
                            mean_val = bg_dict[fname2].float().mean()
                            x_eval[fname2] = mean_val.expand(n_grid_pts)
                        else:
                            # Categorical features are fixed to their mode.
                            mode_val = bg_dict[fname2].long().mode().values
                            x_eval[fname2] = mode_val.expand(n_grid_pts)

                # Get diagonal pair contribution (j=i, k=i)
                phi_j = self.feature_embeddings.embed_feature(fname, x_eval[fname])
                token = self.interaction_tokens.get_token(i, i)
                raw = self.shared_net(phi_j, phi_j, token)
                h = centered_hard_sigmoid(raw).squeeze(-1)
                pair_idx = self.interaction_tokens._pair_to_idx[(i, i)]
                w = self.output_weights[pair_idx]
                effect = (w * h).cpu().numpy()

                result[fname] = (grid.cpu().numpy(), effect)

        return result

    def interaction_surfaces(
        self,
        X_background,
        n_grid: int = 30,
        pairs: Optional[List[Tuple[str, str]]] = None,
    ) -> Dict[Tuple[str, str], Dict]:
        """
        Compute 2D interaction surfaces for feature pairs.

        For pair (j, k): evaluates w_{jk} * h_{jk}(x_j, x_k) over a 2D grid.
        Main effect contributions (j=j and k=k pairs) are NOT subtracted here —
        this shows the raw interaction term. Use pair_contributions() for
        full decomposition.

        Args:
            X_background: Background data for range estimation.
            n_grid: Grid resolution per axis (n_grid x n_grid for continuous).
            pairs: List of (feature_j, feature_k) pairs to compute. If None,
                computes all off-diagonal interaction pairs.

        Returns:
            Dict mapping (fname_j, fname_k) -> {
                'grid_j': array of shape (n_grid,),
                'grid_k': array of shape (n_grid,) or (n_cats,),
                'surface': array of shape (n_grid_j, n_grid_k),
            }
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before interaction_surfaces().")

        self.eval()
        bg_dict = self._prepare_features(X_background)
        bg_dict = self._to_device_dict(bg_dict)

        # All off-diagonal pairs if not specified
        if pairs is None:
            all_pairs = self.interaction_tokens.pair_indices()
            pairs = [
                (self.feature_names[j], self.feature_names[k])
                for j, k in all_pairs
                if j != k
            ]

        result = {}

        with torch.no_grad():
            for (fname_j, fname_k) in pairs:
                j = self.feature_names.index(fname_j)
                k = self.feature_names.index(fname_k)

                spec_j = self.features[fname_j]
                spec_k = self.features[fname_k]

                if spec_j == "continuous":
                    vals_j = bg_dict[fname_j].float()
                    grid_j = torch.linspace(vals_j.min(), vals_j.max(), n_grid, device=self._device)
                else:
                    grid_j = torch.arange(spec_j, device=self._device)

                if spec_k == "continuous":
                    vals_k = bg_dict[fname_k].float()
                    grid_k = torch.linspace(vals_k.min(), vals_k.max(), n_grid, device=self._device)
                else:
                    grid_k = torch.arange(spec_k, device=self._device)

                nj, nk = grid_j.shape[0], grid_k.shape[0]

                # Meshgrid
                gj = grid_j.repeat_interleave(nk)  # (nj * nk,)
                gk = grid_k.repeat(nj)             # (nj * nk,)

                phi_j = self.feature_embeddings.embed_feature(
                    fname_j, gj.long() if isinstance(spec_j, int) else gj
                )
                phi_k = self.feature_embeddings.embed_feature(
                    fname_k, gk.long() if isinstance(spec_k, int) else gk
                )

                token = self.interaction_tokens.get_token(j, k)
                raw = self.shared_net(phi_j, phi_k, token)
                h = centered_hard_sigmoid(raw).squeeze(-1)
                pair_idx = self.interaction_tokens._pair_to_idx[(j, k)]
                w = self.output_weights[pair_idx]
                surface = (w * h).reshape(nj, nk).cpu().numpy()

                result[(fname_j, fname_k)] = {
                    "grid_j": grid_j.cpu().numpy(),
                    "grid_k": grid_k.cpu().numpy(),
                    "surface": surface,
                }

        return result

    def shapley_values(
        self,
        X,
        X_background,
        n_background: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Compute exact Shapley values using the pairwise additive structure.

        Cost: 2(q+1) forward passes per test sample per background sample.
        For q=9, n_background=100: ~2000 forward passes per test sample.
        This is exact — no sampling approximation.

        Args:
            X: Test data.
            X_background: Background data for baseline distribution.
            n_background: Number of background samples to use.

        Returns:
            Dict mapping feature_name -> (n_test,) array of SHAP values.
            Values are on the linear predictor scale.
        """
        # Local import avoids a circular dependency at module load time.
        from insurance_gam.pin.shapley import exact_shapley_values

        if not self._is_fitted:
            raise RuntimeError("Call fit() before shapley_values().")

        x_dict = self._prepare_features(X)
        x_dict = self._to_device_dict(x_dict)
        bg_dict = self._prepare_features(X_background)
        bg_dict = self._to_device_dict(bg_dict)

        return exact_shapley_values(self, x_dict, bg_dict, n_background)

    def interaction_weights(self) -> Dict[Tuple[str, str], float]:
        """
        Return output weights w_{jk} for all pairs as a dict.

        Large absolute weights indicate important interactions or main effects.
        Note: weights are on the scale of the linear predictor; compare
        w_{jk} * h_{jk} range rather than w_{jk} alone for fair comparison.

        Returns:
            Dict mapping (fname_j, fname_k) -> float.
        """
        pairs = self.interaction_tokens.pair_indices()
        weights = self.output_weights.detach().cpu().numpy()
        return {
            (self.feature_names[j], self.feature_names[k]): float(weights[idx])
            for idx, (j, k) in enumerate(pairs)
        }

    def count_parameters(self) -> int:
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
Single Tree-like Pairwise Interaction Network.
Prediction:
f_PIN(x) = exp( sum_{j<=k} w_{jk} * h_{jk}(x) + b )
Arguments:
- features: Dict mapping feature name to spec. Spec is 'continuous' or an int (number of categories). Order matters — features are indexed by position.
- embedding_dim: Feature embedding dimension d (default 10).
- hidden_dim: Hidden width for continuous embedding FNNs d' (default 20).
- token_dim: Interaction token dimension d0 (default 10).
- shared_dims: (d1, d2) widths for shared interaction network (default (30, 20)).
- loss: Loss name — 'poisson', 'gamma', or 'tweedie'.
- tweedie_p: Tweedie power (only used when loss='tweedie').
- lr: Adam learning rate (default 0.001).
- batch_size: Mini-batch size (default 128).
- max_epochs: Maximum training epochs (default 500).
- patience: Early stopping patience in epochs (default 20).
- lr_patience: ReduceLROnPlateau patience (default 5).
- lr_factor: ReduceLROnPlateau reduction factor (default 0.9).
- val_fraction: Fraction of training data for validation if X_val not given (default 0.1).
- device: Torch device string, or None to auto-detect.
- random_seed: Seed for reproducibility.
Examples:
>>> model = PINModel(
...     features={"age": "continuous", "area": 5},
...     loss="poisson",
... )
def __init__(
    self,
    features: FeatureSpec,
    embedding_dim: int = 10,
    hidden_dim: int = 20,
    token_dim: int = 10,
    shared_dims: Tuple[int, int] = (30, 20),
    loss: str = "poisson",
    tweedie_p: float = 1.5,
    lr: float = 0.001,
    batch_size: int = 128,
    max_epochs: int = 500,
    patience: int = 20,
    lr_patience: int = 5,
    lr_factor: float = 0.9,
    val_fraction: float = 0.1,
    device: Optional[str] = None,
    random_seed: int = 42,
) -> None:
    """Build sub-modules, resolve the compute device, set up the loss and
    output parameters, and initialise training state. See the class docstring
    for the meaning of each argument."""
    super().__init__()

    self.features = dict(features)
    # Positional order of features fixes the (j, k) pair indexing below.
    self.feature_names: List[str] = list(features.keys())
    self.q = len(self.feature_names)
    self.embedding_dim = embedding_dim
    self.hidden_dim = hidden_dim
    self.token_dim = token_dim
    self.shared_dims = tuple(shared_dims)
    self.loss_name = loss
    self.tweedie_p = tweedie_p
    self.lr = lr
    self.batch_size = batch_size
    self.max_epochs = max_epochs
    self.patience = patience
    self.lr_patience = lr_patience
    self.lr_factor = lr_factor
    self.val_fraction = val_fraction
    self.random_seed = random_seed

    # Resolve device
    if device is None:
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        self._device = torch.device(device)

    # --- Sub-modules ---
    self.feature_embeddings = FeatureEmbeddings(
        features=features,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
    )
    self.interaction_tokens = InteractionTokens(
        n_features=self.q,
        token_dim=token_dim,
    )
    self.shared_net = SharedInteractionNet(
        embedding_dim=embedding_dim,
        token_dim=token_dim,
        layer1_dim=shared_dims[0],
        layer2_dim=shared_dims[1],
    )

    # All unordered pairs including the diagonal (j == k are main effects).
    n_pairs = self.q * (self.q + 1) // 2
    # Output weights w_{jk} and bias b — linear combination of pair terms
    self.output_weights = nn.Parameter(torch.randn(n_pairs) * 0.01)
    self.output_bias = nn.Parameter(torch.zeros(1))

    # Loss function
    loss_kwargs = {"p": tweedie_p} if loss == "tweedie" else {}
    self._loss_fn = get_loss(loss, **loss_kwargs)

    # Weight initialisation: smaller scales for numerical stability
    self._init_weights()

    # Training state
    self._is_fitted = False
    self.train_history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}

    # Centering offsets (set post-hoc after fitting)
    # h_{jk}^centered(x) = h_{jk}(x) - mean_train[h_{jk}]
    # We store mean_train[w_{jk} * h_{jk}] per pair for efficiency
    self._pair_means: Optional[torch.Tensor] = None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(
    self,
    x_dict: Dict[str, torch.Tensor],
    exposure: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Return f_PIN(x) = exp(linear_predictor), optionally scaled by exposure.

    Without exposure the output is a frequency (claims per year); supplying
    a per-sample exposure tensor rescales it into an expected claim count.
    Both calling modes are supported.

    Args:
        x_dict: Dict of feature_name -> (batch,) tensor.
        exposure: Optional per-sample exposure, shape (batch,).

    Returns:
        Predicted frequency, shape (batch,).
    """
    linear_pred = self._compute_linear_predictor(x_dict)
    # Bound the linear predictor so the exponential link cannot overflow.
    frequency = torch.exp(linear_pred.clamp(min=-20.0, max=20.0))
    return frequency if exposure is None else frequency * exposure
Compute predictions f_PIN(x) = exp(linear_predictor) * exposure.
When exposure is provided, the raw model output is frequency (claims per year) and multiplying by exposure gives expected claim count. For frequency models, typically you'd pass exposure=None and let the caller multiply; this method supports both modes.
Arguments:
- x_dict: Feature tensors. Dict of feature_name -> (batch,) tensor.
- exposure: Optional per-sample exposure, shape (batch,).
Returns:
Predicted frequency, shape (batch,).
339 def fit( 340 self, 341 X_train, 342 y_train: np.ndarray, 343 exposure: Optional[np.ndarray] = None, 344 X_val=None, 345 y_val: Optional[np.ndarray] = None, 346 exposure_val: Optional[np.ndarray] = None, 347 verbose: bool = True, 348 ) -> "PINModel": 349 """ 350 Fit the model. 351 352 Args: 353 X_train: Training features. Dict, polars.DataFrame, or pandas.DataFrame. 354 y_train: Observed frequency (claims / exposure), shape (n,). 355 exposure: Per-sample exposure (years at risk), shape (n,). 356 X_val: Validation features (optional). If None, 10% of training data 357 is reserved. 358 y_val: Validation targets. 359 exposure_val: Validation exposure. 360 verbose: Print training progress. 361 362 Returns: 363 self (for chaining). 364 """ 365 torch.manual_seed(self.random_seed) 366 np.random.seed(self.random_seed) 367 368 self.to(self._device) 369 370 # Prepare training data 371 x_dict = self._prepare_features(X_train) 372 x_dict = self._to_device_dict(x_dict) 373 y_t = _to_tensor(y_train).to(self._device) 374 exp_t = ( 375 _to_tensor(exposure).to(self._device) 376 if exposure is not None 377 else torch.ones_like(y_t) 378 ) 379 380 n = y_t.shape[0] 381 382 # Build or reserve validation set (must happen before bias init so we 383 # use training-only data for the mean frequency estimate). 
384 if X_val is not None: 385 x_val_dict = self._prepare_features(X_val) 386 x_val_dict = self._to_device_dict(x_val_dict) 387 y_val_t = _to_tensor(y_val).to(self._device) 388 exp_val_t = ( 389 _to_tensor(exposure_val).to(self._device) 390 if exposure_val is not None 391 else torch.ones_like(y_val_t) 392 ) 393 else: 394 # Reserve val_fraction from training 395 val_size = max(1, int(n * self.val_fraction)) 396 perm = torch.randperm(n, device=self._device) 397 val_idx = perm[:val_size] 398 train_idx = perm[val_size:] 399 400 x_val_dict = {k: v[val_idx] for k, v in x_dict.items()} 401 y_val_t = y_t[val_idx] 402 exp_val_t = exp_t[val_idx] 403 404 x_dict = {k: v[train_idx] for k, v in x_dict.items()} 405 y_t = y_t[train_idx] 406 exp_t = exp_t[train_idx] 407 408 n_train = y_t.shape[0] 409 410 # Initialise bias to log(mean_frequency) of the TRAINING split only. 411 # Computing this from the full dataset (including validation) would 412 # be a mild form of data leakage — small in practice but wrong in 413 # principle. We compute it here, after the split. 
414 with torch.no_grad(): 415 mean_freq = y_t.mean().clamp(min=1e-8) 416 self.output_bias.fill_(torch.log(mean_freq).item()) 417 418 optimizer = optim.Adam(self.parameters(), lr=self.lr) 419 scheduler = optim.lr_scheduler.ReduceLROnPlateau( 420 optimizer, 421 mode="min", 422 factor=self.lr_factor, 423 patience=self.lr_patience, 424 ) 425 426 best_val_loss = float("inf") 427 best_state = None 428 epochs_no_improve = 0 429 430 self.train_history = {"train_loss": [], "val_loss": []} 431 432 for epoch in range(self.max_epochs): 433 self.train() 434 # Shuffle training data 435 perm = torch.randperm(n_train, device=self._device) 436 437 epoch_loss = 0.0 438 n_batches = 0 439 440 for start in range(0, n_train, self.batch_size): 441 end = min(start + self.batch_size, n_train) 442 batch_idx = perm[start:end] 443 444 x_batch = {k: v[batch_idx] for k, v in x_dict.items()} 445 y_batch = y_t[batch_idx] 446 exp_batch = exp_t[batch_idx] 447 448 optimizer.zero_grad() 449 mu = self.forward(x_batch) 450 loss = self._loss_fn(mu, y_batch, exp_batch) 451 if torch.isnan(loss): 452 continue # skip NaN batches 453 loss.backward() 454 # Gradient clipping prevents exploding gradients in early training 455 torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0) 456 # Zero NaN gradients before step 457 for p in self.parameters(): 458 if p.grad is not None and torch.isnan(p.grad).any(): 459 p.grad.zero_() 460 optimizer.step() 461 # Restore any NaN params to their pre-step values 462 for p in self.parameters(): 463 if torch.isnan(p).any(): 464 torch.nan_to_num_(p, nan=0.0) 465 466 epoch_loss += loss.item() 467 n_batches += 1 468 469 train_loss = epoch_loss / n_batches 470 471 # Validation 472 self.eval() 473 with torch.no_grad(): 474 mu_val = self.forward(x_val_dict) 475 val_loss = self._loss_fn(mu_val, y_val_t, exp_val_t).item() 476 477 scheduler.step(val_loss) 478 479 self.train_history["train_loss"].append(train_loss) 480 self.train_history["val_loss"].append(val_loss) 481 482 if 
verbose and (epoch % 50 == 0 or epoch < 10): 483 lr_now = optimizer.param_groups[0]["lr"] 484 print( 485 f"Epoch {epoch:4d} | train={train_loss:.6f} | val={val_loss:.6f} | lr={lr_now:.2e}" 486 ) 487 488 # Early stopping 489 if val_loss < best_val_loss: 490 best_val_loss = val_loss 491 best_state = copy.deepcopy(self.state_dict()) 492 epochs_no_improve = 0 493 else: 494 epochs_no_improve += 1 495 if epochs_no_improve >= self.patience: 496 if verbose: 497 print(f"Early stopping at epoch {epoch}. Best val={best_val_loss:.6f}") 498 break 499 500 # Restore best weights 501 if best_state is not None: 502 self.load_state_dict(best_state) 503 504 # Post-hoc centering: compute pair means on training data 505 self._compute_pair_centering(x_dict) 506 507 self._is_fitted = True 508 return self
Fit the model.
Arguments:
- X_train: Training features. Dict, polars.DataFrame, or pandas.DataFrame.
- y_train: Observed frequency (claims / exposure), shape (n,).
- exposure: Per-sample exposure (years at risk), shape (n,).
- X_val: Validation features (optional). If None, 10% of training data is reserved.
- y_val: Validation targets.
- exposure_val: Validation exposure.
- verbose: Print training progress.
Returns:
self (for chaining).
556 def predict( 557 self, 558 X, 559 exposure: Optional[np.ndarray] = None, 560 ) -> np.ndarray: 561 """ 562 Predict frequency (or expected claims if exposure given). 563 564 Args: 565 X: Features. Dict, polars.DataFrame, or pandas.DataFrame. 566 exposure: Per-sample exposure. 567 568 Returns: 569 Predictions as numpy array, shape (n,). 570 """ 571 if not self._is_fitted: 572 raise RuntimeError("Call fit() before predict().") 573 574 self.eval() 575 x_dict = self._prepare_features(X) 576 x_dict = self._to_device_dict(x_dict) 577 578 exp_t = ( 579 _to_tensor(exposure).to(self._device) 580 if exposure is not None 581 else None 582 ) 583 584 with torch.no_grad(): 585 mu = self.forward(x_dict, exposure=exp_t) 586 587 return mu.cpu().numpy()
Predict frequency (or expected claims if exposure given).
Arguments:
- X: Features. Dict, polars.DataFrame, or pandas.DataFrame.
- exposure: Per-sample exposure.
Returns:
Predictions as numpy array, shape (n,).
593 def pair_contributions( 594 self, 595 X, 596 ) -> Dict[Tuple[str, str], np.ndarray]: 597 """ 598 Compute w_{jk} * h_{jk}(x) for each pair and each sample. 599 600 These are the additive components on the linear predictor scale. 601 Useful for understanding which pairs drive predictions. 602 603 Args: 604 X: Features. 605 606 Returns: 607 Dict mapping (fname_j, fname_k) -> array of shape (n,). 608 """ 609 if not self._is_fitted: 610 raise RuntimeError("Call fit() before pair_contributions().") 611 612 self.eval() 613 x_dict = self._prepare_features(X) 614 x_dict = self._to_device_dict(x_dict) 615 pairs = self.interaction_tokens.pair_indices() 616 result = {} 617 618 with torch.no_grad(): 619 embeddings = self.feature_embeddings.embed_all(x_dict) 620 621 for pair_idx, (j, k) in enumerate(pairs): 622 fname_j = self.feature_names[j] 623 fname_k = self.feature_names[k] 624 phi_j = embeddings[fname_j] 625 phi_k = embeddings[fname_k] 626 token = self.interaction_tokens.get_token(j, k) 627 raw = self.shared_net(phi_j, phi_k, token) 628 h = centered_hard_sigmoid(raw).squeeze(-1) 629 w = self.output_weights[pair_idx] 630 contrib = (w * h) 631 if self._pair_means is not None: 632 contrib = contrib - self._pair_means[pair_idx] 633 result[(fname_j, fname_k)] = contrib.cpu().numpy() 634 635 return result
Compute w_{jk} * h_{jk}(x) for each pair and each sample.
These are the additive components on the linear predictor scale. Useful for understanding which pairs drive predictions.
Arguments:
- X: Features.
Returns:
Dict mapping (fname_j, fname_k) -> array of shape (n,).
637 def main_effects( 638 self, 639 X_background, 640 n_grid: int = 100, 641 ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: 642 """ 643 Compute main effect curves for each feature. 644 645 For continuous features: evaluates h_{jj} over an evenly-spaced grid 646 while fixing all other features to background means. 647 648 Returns: 649 Dict mapping feature_name -> (grid_values, effect_values). 650 For categoricals: grid_values are integer category codes. 651 """ 652 if not self._is_fitted: 653 raise RuntimeError("Call fit() before main_effects().") 654 655 self.eval() 656 bg_dict = self._prepare_features(X_background) 657 bg_dict = self._to_device_dict(bg_dict) 658 659 result = {} 660 661 with torch.no_grad(): 662 for i, fname in enumerate(self.feature_names): 663 spec = self.features[fname] 664 665 if spec == "continuous": 666 vals = bg_dict[fname].float() 667 lo, hi = vals.min().item(), vals.max().item() 668 grid = torch.linspace(lo, hi, n_grid, device=self._device) 669 elif isinstance(spec, int): 670 grid = torch.arange(spec, device=self._device) 671 else: 672 continue 673 674 # Build batch: vary feature i, fix others to background mean 675 n_grid_pts = grid.shape[0] 676 x_eval = {} 677 for fname2 in self.feature_names: 678 spec2 = self.features[fname2] 679 if fname2 == fname: 680 x_eval[fname2] = grid.long() if isinstance(spec, int) else grid 681 else: 682 if spec2 == "continuous": 683 mean_val = bg_dict[fname2].float().mean() 684 x_eval[fname2] = mean_val.expand(n_grid_pts) 685 else: 686 mode_val = bg_dict[fname2].long().mode().values 687 x_eval[fname2] = mode_val.expand(n_grid_pts) 688 689 # Get diagonal pair contribution (j=i, k=i) 690 phi_j = self.feature_embeddings.embed_feature(fname, x_eval[fname]) 691 token = self.interaction_tokens.get_token(i, i) 692 raw = self.shared_net(phi_j, phi_j, token) 693 h = centered_hard_sigmoid(raw).squeeze(-1) 694 pair_idx = self.interaction_tokens._pair_to_idx[(i, i)] 695 w = self.output_weights[pair_idx] 696 effect = (w * 
h).cpu().numpy() 697 698 result[fname] = (grid.cpu().numpy(), effect) 699 700 return result
Compute main effect curves for each feature.
For continuous features: evaluates h_{jj} over an evenly-spaced grid while fixing all other features to background means.
Returns:
Dict mapping feature_name -> (grid_values, effect_values). For categoricals: grid_values are integer category codes.
    def interaction_surfaces(
        self,
        X_background,
        n_grid: int = 30,
        pairs: Optional[List[Tuple[str, str]]] = None,
    ) -> Dict[Tuple[str, str], Dict]:
        """
        Compute 2D interaction surfaces for feature pairs.

        For pair (j, k): evaluates w_{jk} * h_{jk}(x_j, x_k) over a 2D grid.
        Main effect contributions (j=j and k=k pairs) are NOT subtracted here —
        this shows the raw interaction term. Use pair_contributions() for
        full decomposition.

        Args:
            X_background: Background data for range estimation.
            n_grid: Grid resolution per axis (n_grid x n_grid for continuous).
            pairs: List of (feature_j, feature_k) pairs to compute. If None,
                computes all off-diagonal interaction pairs.
                NOTE(review): user-supplied entries appear to require
                feature_j to precede feature_k in the model's feature order —
                the `_pair_to_idx[(j, k)]` lookup below would otherwise raise
                KeyError. Confirm against InteractionTokens.

        Returns:
            Dict mapping (fname_j, fname_k) -> {
                'grid_j': array of shape (n_grid,),
                'grid_k': array of shape (n_grid,) or (n_cats,),
                'surface': array of shape (n_grid_j, n_grid_k),
            }

        Raises:
            RuntimeError: If called before fit().
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before interaction_surfaces().")

        self.eval()
        bg_dict = self._prepare_features(X_background)
        bg_dict = self._to_device_dict(bg_dict)

        # All off-diagonal pairs if not specified
        if pairs is None:
            all_pairs = self.interaction_tokens.pair_indices()
            pairs = [
                (self.feature_names[j], self.feature_names[k])
                for j, k in all_pairs
                if j != k
            ]

        result = {}

        with torch.no_grad():
            for (fname_j, fname_k) in pairs:
                # Positional indices of the two features; order matters for
                # the token / output-weight lookups below.
                j = self.feature_names.index(fname_j)
                k = self.feature_names.index(fname_k)

                spec_j = self.features[fname_j]
                spec_k = self.features[fname_k]

                # Axis grids: evenly spaced over the background range for
                # continuous features, all category codes for categoricals.
                if spec_j == "continuous":
                    vals_j = bg_dict[fname_j].float()
                    grid_j = torch.linspace(vals_j.min(), vals_j.max(), n_grid, device=self._device)
                else:
                    grid_j = torch.arange(spec_j, device=self._device)

                if spec_k == "continuous":
                    vals_k = bg_dict[fname_k].float()
                    grid_k = torch.linspace(vals_k.min(), vals_k.max(), n_grid, device=self._device)
                else:
                    grid_k = torch.arange(spec_k, device=self._device)

                nj, nk = grid_j.shape[0], grid_k.shape[0]

                # Meshgrid: gj varies slowest and gk fastest, so the
                # reshape(nj, nk) below recovers the (j-axis, k-axis) layout.
                gj = grid_j.repeat_interleave(nk)  # (nj * nk,)
                gk = grid_k.repeat(nj)  # (nj * nk,)

                phi_j = self.feature_embeddings.embed_feature(
                    fname_j, gj.long() if isinstance(spec_j, int) else gj
                )
                phi_k = self.feature_embeddings.embed_feature(
                    fname_k, gk.long() if isinstance(spec_k, int) else gk
                )

                token = self.interaction_tokens.get_token(j, k)
                raw = self.shared_net(phi_j, phi_k, token)
                h = centered_hard_sigmoid(raw).squeeze(-1)
                pair_idx = self.interaction_tokens._pair_to_idx[(j, k)]
                w = self.output_weights[pair_idx]
                surface = (w * h).reshape(nj, nk).cpu().numpy()

                result[(fname_j, fname_k)] = {
                    "grid_j": grid_j.cpu().numpy(),
                    "grid_k": grid_k.cpu().numpy(),
                    "surface": surface,
                }

        return result
Compute 2D interaction surfaces for feature pairs.
For pair (j, k): evaluates w_{jk} * h_{jk}(x_j, x_k) over a 2D grid. Main effect contributions (j=j and k=k pairs) are NOT subtracted here — this shows the raw interaction term. Use pair_contributions() for full decomposition.
Arguments:
- X_background: Background data for range estimation.
- n_grid: Grid resolution per axis (n_grid x n_grid for continuous).
- pairs: List of (feature_j, feature_k) pairs to compute. If None, computes all off-diagonal interaction pairs.
Returns:
Dict mapping (fname_j, fname_k) -> { 'grid_j': array of shape (n_grid,), 'grid_k': array of shape (n_grid,) or (n_cats,), 'surface': array of shape (n_grid_j, n_grid_k), }
795 def shapley_values( 796 self, 797 X, 798 X_background, 799 n_background: int = 100, 800 ) -> Dict[str, np.ndarray]: 801 """ 802 Compute exact Shapley values using the pairwise additive structure. 803 804 Cost: 2(q+1) forward passes per test sample per background sample. 805 For q=9, n_background=100: ~2000 forward passes per test sample. 806 This is exact — no sampling approximation. 807 808 Args: 809 X: Test data. 810 X_background: Background data for baseline distribution. 811 n_background: Number of background samples to use. 812 813 Returns: 814 Dict mapping feature_name -> (n_test,) array of SHAP values. 815 Values are on the linear predictor scale. 816 """ 817 from insurance_gam.pin.shapley import exact_shapley_values 818 819 if not self._is_fitted: 820 raise RuntimeError("Call fit() before shapley_values().") 821 822 x_dict = self._prepare_features(X) 823 x_dict = self._to_device_dict(x_dict) 824 bg_dict = self._prepare_features(X_background) 825 bg_dict = self._to_device_dict(bg_dict) 826 827 return exact_shapley_values(self, x_dict, bg_dict, n_background)
Compute exact Shapley values using the pairwise additive structure.
Cost: 2(q+1) forward passes per test sample per background sample. For q=9, n_background=100: ~2000 forward passes per test sample. This is exact — no sampling approximation.
Arguments:
- X: Test data.
- X_background: Background data for baseline distribution.
- n_background: Number of background samples to use.
Returns:
Dict mapping feature_name -> (n_test,) array of SHAP values. Values are on the linear predictor scale.
829 def interaction_weights(self) -> Dict[Tuple[str, str], float]: 830 """ 831 Return output weights w_{jk} for all pairs as a dict. 832 833 Large absolute weights indicate important interactions or main effects. 834 Note: weights are on the scale of the linear predictor; compare 835 w_{jk} * h_{jk} range rather than w_{jk} alone for fair comparison. 836 837 Returns: 838 Dict mapping (fname_j, fname_k) -> float. 839 """ 840 pairs = self.interaction_tokens.pair_indices() 841 weights = self.output_weights.detach().cpu().numpy() 842 return { 843 (self.feature_names[j], self.feature_names[k]): float(weights[idx]) 844 for idx, (j, k) in enumerate(pairs) 845 }
Return output weights w_{jk} for all pairs as a dict.
Large absolute weights indicate important interactions or main effects. Note: weights are on the scale of the linear predictor; compare w_{jk} * h_{jk} range rather than w_{jk} alone for fair comparison.
Returns:
Dict mapping (fname_j, fname_k) -> float.
class PINEnsemble:
    """
    Ensemble of PINModel instances trained with different random seeds.

    Ensemble averaging is the primary regularisation strategy in the PIN paper.
    With n=10 runs, ensemble PIN achieves the best published result on French MTPL.

    Args:
        n_models: Number of models in the ensemble.
        **kwargs: Passed to each PINModel constructor.
    """

    def __init__(self, n_models: int = 10, **kwargs) -> None:
        self.n_models = n_models
        self.kwargs = kwargs
        self.models: List[PINModel] = []
        self._is_fitted = False

    def fit(
        self,
        X_train,
        y_train: np.ndarray,
        exposure: Optional[np.ndarray] = None,
        X_val=None,
        y_val: Optional[np.ndarray] = None,
        exposure_val: Optional[np.ndarray] = None,
        verbose: bool = False,
    ) -> "PINEnsemble":
        """
        Fit all models with different seeds.

        Args:
            X_train: Training features.
            y_train: Observed frequency.
            exposure: Training exposure.
            X_val: Validation features (optional).
            y_val: Validation targets.
            exposure_val: Validation exposure.
            verbose: Print per-model progress.

        Returns:
            self.
        """
        self.models = []
        for i in range(self.n_models):
            # Offset the base seed per member so runs differ but remain
            # reproducible for a fixed configuration.
            seed = self.kwargs.get("random_seed", 42) + i
            kw = {**self.kwargs, "random_seed": seed}
            model = PINModel(**kw)
            print(f"[Ensemble] Fitting model {i+1}/{self.n_models} (seed={seed})")
            model.fit(
                X_train,
                y_train,
                exposure=exposure,
                X_val=X_val,
                y_val=y_val,
                exposure_val=exposure_val,
                verbose=verbose,
            )
            self.models.append(model)

        self._is_fitted = True
        return self

    def predict(
        self,
        X,
        exposure: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Average predictions across all models.

        Args:
            X: Features.
            exposure: Per-sample exposure.

        Returns:
            Ensemble mean prediction, shape (n,).

        Raises:
            RuntimeError: If called before fit().
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before predict().")

        preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0)
        return preds.mean(axis=0)

    def predict_std(
        self,
        X,
        exposure: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Standard deviation of predictions across models (epistemic uncertainty).

        Args:
            X: Features.
            exposure: Per-sample exposure.

        Returns:
            Shape (n,).

        Raises:
            RuntimeError: If called before fit().
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before predict_std().")

        preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0)
        return preds.std(axis=0)

    def shapley_values(
        self,
        X,
        X_background,
        n_background: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Average Shapley values across ensemble members.

        Args:
            X: Test data.
            X_background: Background data for baseline distribution.
            n_background: Number of background samples to use.

        Returns:
            Dict mapping feature_name -> (n_test,) array.

        Raises:
            RuntimeError: If called before fit().
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before shapley_values().")

        all_shaps = [m.shapley_values(X, X_background, n_background) for m in self.models]
        feature_names = list(all_shaps[0].keys())

        return {
            fname: np.stack([s[fname] for s in all_shaps], axis=0).mean(axis=0)
            for fname in feature_names
        }

    def interaction_weights(self) -> Dict[Tuple[str, str], float]:
        """
        Mean absolute interaction weights across ensemble.

        Returns:
            Dict mapping (fname_j, fname_k) -> mean |w_{jk}|.

        Raises:
            RuntimeError: If called before fit(). (Fix: previously an
            unfitted ensemble silently returned an empty dict here, unlike
            every other public method.)
        """
        from collections import defaultdict

        if not self._is_fitted:
            raise RuntimeError("Call fit() before interaction_weights().")

        sums: Dict = defaultdict(float)
        for model in self.models:
            for pair, w in model.interaction_weights().items():
                sums[pair] += abs(w)
        # Fix: average over the models actually fitted rather than the
        # configured n_models, so the mean stays correct if the two ever
        # diverge (e.g. models list replaced externally).
        n = len(self.models)
        return {pair: v / n for pair, v in sums.items()}
Ensemble of PINModel instances trained with different random seeds.
Ensemble averaging is the primary regularisation strategy in the PIN paper. With n=10 runs, ensemble PIN achieves the best published result on French MTPL.
Arguments:
- n_models: Number of models in the ensemble.
- **kwargs: Passed to each PINModel constructor.
870 def fit( 871 self, 872 X_train, 873 y_train: np.ndarray, 874 exposure: Optional[np.ndarray] = None, 875 X_val=None, 876 y_val: Optional[np.ndarray] = None, 877 exposure_val: Optional[np.ndarray] = None, 878 verbose: bool = False, 879 ) -> "PINEnsemble": 880 """ 881 Fit all models with different seeds. 882 883 Args: 884 X_train: Training features. 885 y_train: Observed frequency. 886 exposure: Training exposure. 887 X_val: Validation features (optional). 888 y_val: Validation targets. 889 exposure_val: Validation exposure. 890 verbose: Print per-model progress. 891 892 Returns: 893 self. 894 """ 895 self.models = [] 896 for i in range(self.n_models): 897 seed = self.kwargs.get("random_seed", 42) + i 898 kw = {**self.kwargs, "random_seed": seed} 899 model = PINModel(**kw) 900 print(f"[Ensemble] Fitting model {i+1}/{self.n_models} (seed={seed})") 901 model.fit( 902 X_train, 903 y_train, 904 exposure=exposure, 905 X_val=X_val, 906 y_val=y_val, 907 exposure_val=exposure_val, 908 verbose=verbose, 909 ) 910 self.models.append(model) 911 912 self._is_fitted = True 913 return self
Fit all models with different seeds.
Arguments:
- X_train: Training features.
- y_train: Observed frequency.
- exposure: Training exposure.
- X_val: Validation features (optional).
- y_val: Validation targets.
- exposure_val: Validation exposure.
- verbose: Print per-model progress.
Returns:
self.
915 def predict( 916 self, 917 X, 918 exposure: Optional[np.ndarray] = None, 919 ) -> np.ndarray: 920 """ 921 Average predictions across all models. 922 923 Args: 924 X: Features. 925 exposure: Per-sample exposure. 926 927 Returns: 928 Ensemble mean prediction, shape (n,). 929 """ 930 if not self._is_fitted: 931 raise RuntimeError("Call fit() before predict().") 932 933 preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0) 934 return preds.mean(axis=0)
Average predictions across all models.
Arguments:
- X: Features.
- exposure: Per-sample exposure.
Returns:
Ensemble mean prediction, shape (n,).
936 def predict_std( 937 self, 938 X, 939 exposure: Optional[np.ndarray] = None, 940 ) -> np.ndarray: 941 """ 942 Standard deviation of predictions across models (epistemic uncertainty). 943 944 Returns: 945 Shape (n,). 946 """ 947 if not self._is_fitted: 948 raise RuntimeError("Call fit() before predict_std().") 949 950 preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0) 951 return preds.std(axis=0)
Standard deviation of predictions across models (epistemic uncertainty).
Returns:
Shape (n,).
953 def shapley_values( 954 self, 955 X, 956 X_background, 957 n_background: int = 100, 958 ) -> Dict[str, np.ndarray]: 959 """ 960 Average Shapley values across ensemble members. 961 962 Returns: 963 Dict mapping feature_name -> (n_test,) array. 964 """ 965 if not self._is_fitted: 966 raise RuntimeError("Call fit() before shapley_values().") 967 968 all_shaps = [m.shapley_values(X, X_background, n_background) for m in self.models] 969 feature_names = list(all_shaps[0].keys()) 970 971 return { 972 fname: np.stack([s[fname] for s in all_shaps], axis=0).mean(axis=0) 973 for fname in feature_names 974 }
Average Shapley values across ensemble members.
Returns:
Dict mapping feature_name -> (n_test,) array.
976 def interaction_weights(self) -> Dict[Tuple[str, str], float]: 977 """ 978 Mean absolute interaction weights across ensemble. 979 980 Returns: 981 Dict mapping (fname_j, fname_k) -> mean |w_{jk}|. 982 """ 983 from collections import defaultdict 984 sums: Dict = defaultdict(float) 985 for model in self.models: 986 for pair, w in model.interaction_weights().items(): 987 sums[pair] += abs(w) 988 return {pair: v / self.n_models for pair, v in sums.items()}
Mean absolute interaction weights across ensemble.
Returns:
Dict mapping (fname_j, fname_k) -> mean |w_{jk}|.
class PINDiagnostics:
    """
    Visualisation tools for a fitted PINModel.

    Args:
        model: Fitted PINModel instance.
    """

    def __init__(self, model) -> None:
        # Held by reference — the model must already be fitted before the
        # plotting methods are called.
        self.model = model

    def interaction_heatmap(
        self,
        figsize: Tuple[int, int] = (8, 7),
        cmap: str = "RdBu_r",
        title: str = "PIN Interaction Weights |w_{jk}|",
        ax=None,
    ):
        """
        Heatmap of interaction weight magnitudes.

        Rows and columns are features. The upper triangle (j<k) shows
        interactions; the diagonal shows main effects.

        Large values indicate pairs that contribute more variance to the
        linear predictor. Note: weight magnitude alone doesn't measure
        importance — the range of h_{jk} also matters. Use
        weighted_importance() for a more meaningful ranking.

        Args:
            figsize: Figure size.
            cmap: Colormap.
            title: Plot title.
            ax: Existing axes to draw on (optional).

        Returns:
            (fig, ax) tuple.
        """
        plt = _get_plt()
        weights = self.model.interaction_weights()
        feature_names = self.model.feature_names
        q = len(feature_names)

        # Symmetrize |w| into a q x q matrix: weights are stored once per
        # pair (j<=k), so mirror off-diagonal entries.
        matrix = np.zeros((q, q))
        for (fname_j, fname_k), w in weights.items():
            j = feature_names.index(fname_j)
            k = feature_names.index(fname_k)
            matrix[j, k] = abs(w)
            if j != k:
                matrix[k, j] = abs(w)

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.get_figure()

        im = ax.imshow(matrix, cmap=cmap, aspect="auto")
        plt.colorbar(im, ax=ax, label="|w_{jk}|")

        ax.set_xticks(range(q))
        ax.set_yticks(range(q))
        ax.set_xticklabels(feature_names, rotation=45, ha="right")
        ax.set_yticklabels(feature_names)
        ax.set_title(title)

        # Annotate each cell with its numeric value for readability.
        for i in range(q):
            for j in range(q):
                ax.text(
                    j, i, f"{matrix[i, j]:.3f}",
                    ha="center", va="center", fontsize=7,
                )

        fig.tight_layout()
        return fig, ax

    def weighted_importance(
        self,
        X_background,
        top_n: Optional[int] = None,
        figsize: Tuple[int, int] = (8, 5),
        ax=None,
    ):
        """
        Rank interaction pairs by the range of w_{jk} * h_{jk}(x).

        The range (max - min) over the background data measures how much
        the pair actually varies — a large weight with a flat h_{jk} adds
        almost nothing. This gives a fairer importance metric than |w_{jk}|.

        Args:
            X_background: Background data for evaluating h_{jk}.
            top_n: Plot only top N pairs. None = all.
            figsize: Figure size.
            ax: Existing axes.

        Returns:
            (fig, ax, importance_dict).
        """
        plt = _get_plt()
        contribs = self.model.pair_contributions(X_background)

        # Importance = spread of the pair's contribution over the background.
        importance = {}
        for (fname_j, fname_k), vals in contribs.items():
            label = f"{fname_j}" if fname_j == fname_k else f"{fname_j} x {fname_k}"
            importance[label] = float(vals.max() - vals.min())

        sorted_items = sorted(importance.items(), key=lambda x: x[1], reverse=True)
        if top_n is not None:
            sorted_items = sorted_items[:top_n]

        labels = [it[0] for it in sorted_items]
        values = [it[1] for it in sorted_items]

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.get_figure()

        # Main effects (single-feature labels) in blue, interactions
        # (labels containing " x ") in orange.
        colors = [
            "#2196F3" if " x " not in label else "#FF5722"
            for label in labels
        ]

        # Reverse so the most important pair appears at the top of the chart.
        ax.barh(labels[::-1], values[::-1], color=colors[::-1])
        ax.set_xlabel("Range of w_{jk} * h_{jk}(x) on background")
        ax.set_title("PIN Pair Importance (main effects=blue, interactions=orange)")
        ax.grid(axis="x", alpha=0.3)
        fig.tight_layout()

        return fig, ax, dict(sorted_items)

    def plot_main_effect(
        self,
        feature: str,
        X_background,
        n_grid: int = 100,
        figsize: Tuple[int, int] = (6, 4),
        ax=None,
        color: str = "#2196F3",
    ):
        """
        Plot main effect curve for a single feature.

        Evaluates the diagonal pair term w_{jj} * h_{jj}(x_j) over a grid
        of feature values while fixing other features to background means.

        Args:
            feature: Feature name.
            X_background: Background data.
            n_grid: Grid resolution.
            figsize: Figure size.
            ax: Existing axes.
            color: Line colour.

        Returns:
            (fig, ax) tuple.

        Raises:
            ValueError: If the feature is unknown to the model.
        """
        plt = _get_plt()
        # NOTE(review): main_effects computes curves for ALL features; this
        # plots just one — fine for occasional use, wasteful in a loop.
        effects = self.model.main_effects(X_background, n_grid=n_grid)

        if feature not in effects:
            raise ValueError(f"Feature '{feature}' not found. Available: {list(effects.keys())}")

        grid_vals, effect_vals = effects[feature]

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.get_figure()

        # Bars for categoricals, a filled curve for continuous features.
        spec = self.model.features[feature]
        if isinstance(spec, int):
            ax.bar(grid_vals, effect_vals, color=color, alpha=0.8)
            ax.set_xlabel(f"{feature} (category)")
        else:
            ax.plot(grid_vals, effect_vals, color=color, linewidth=2)
            ax.fill_between(grid_vals, effect_vals, alpha=0.15, color=color)
            ax.set_xlabel(feature)

        ax.axhline(0, color="gray", linestyle="--", alpha=0.5)
        ax.set_ylabel("w_{jj} * h_{jj}(x) (linear predictor scale)")
        ax.set_title(f"Main Effect: {feature}")
        ax.grid(alpha=0.3)
        fig.tight_layout()

        return fig, ax

    def plot_surface(
        self,
        feature_j: str,
        feature_k: str,
        X_background,
        n_grid: int = 30,
        figsize: Tuple[int, int] = (7, 5),
        cmap: str = "RdBu_r",
        ax=None,
    ):
        """
        Plot 2D interaction surface for a feature pair.

        Evaluates w_{jk} * h_{jk}(x_j, x_k) over a grid. Points with
        high absolute values indicate combinations where the interaction
        adds or subtracts substantially from the linear predictor.

        Args:
            feature_j: First feature (x-axis for continuous).
            feature_k: Second feature (y-axis or color dimension).
            X_background: Background data for axis range estimation.
            n_grid: Grid resolution per axis.
            figsize: Figure size.
            cmap: Colormap.
            ax: Existing axes.

        Returns:
            (fig, ax) tuple.

        Raises:
            ValueError: If the pair cannot be resolved in either order.
        """
        plt = _get_plt()
        surfaces = self.model.interaction_surfaces(
            X_background,
            n_grid=n_grid,
            pairs=[(feature_j, feature_k)],
        )

        # Try both key orders — the model stores each pair once.
        key = (feature_j, feature_k)
        if key not in surfaces:
            key = (feature_k, feature_j)

        if key not in surfaces:
            raise ValueError(
                f"Pair ({feature_j}, {feature_k}) not found in surfaces. "
                "Ensure both features exist and j < k in feature order."
            )

        surf_data = surfaces[key]
        grid_j = surf_data["grid_j"]
        grid_k = surf_data["grid_k"]
        surface = surf_data["surface"]

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.get_figure()

        # Symmetric color limits so zero maps to the diverging colormap's
        # midpoint. (A flat all-zero surface gives vmin == vmax == 0.)
        vmax = np.abs(surface).max()
        vmin = -vmax

        spec_j = self.model.features[feature_j]
        spec_k = self.model.features[feature_k]

        if isinstance(spec_k, int):
            # Categorical second feature: one line per category code.
            for cat_idx in range(grid_k.shape[0]):
                ax.plot(
                    grid_j,
                    surface[:, cat_idx],
                    label=f"{feature_k}={int(grid_k[cat_idx])}",
                    alpha=0.8,
                )
            ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
            ax.set_xlabel(feature_j)
            ax.set_ylabel("w * h (interaction)")
        else:
            # Continuous x continuous: pcolormesh expects C with shape
            # (len(y), len(x)), hence the transpose.
            im = ax.pcolormesh(
                grid_j, grid_k, surface.T,
                cmap=cmap, vmin=vmin, vmax=vmax,
            )
            plt.colorbar(im, ax=ax, label="w_{jk} * h_{jk}(x)")
            ax.set_xlabel(feature_j)
            ax.set_ylabel(feature_k)

        ax.set_title(f"Interaction Surface: {feature_j} x {feature_k}")
        fig.tight_layout()

        return fig, ax

    def plot_training_history(
        self,
        figsize: Tuple[int, int] = (8, 4),
        ax=None,
    ):
        """
        Plot training and validation loss curves.

        Args:
            figsize: Figure size.
            ax: Existing axes.

        Returns:
            (fig, ax) tuple.
        """
        plt = _get_plt()
        history = self.model.train_history

        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        else:
            fig = ax.get_figure()

        epochs = range(len(history["train_loss"]))
        ax.plot(epochs, history["train_loss"], label="Train", color="#2196F3")
        ax.plot(epochs, history["val_loss"], label="Validation", color="#FF5722")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss (deviance)")
        ax.set_title("PIN Training History")
        ax.legend()
        ax.grid(alpha=0.3)
        fig.tight_layout()

        return fig, ax

    def summary(self, X_background=None) -> str:
        """
        Print a text summary of the fitted model.

        Args:
            X_background: Optional background data for pair importance ranking.
                NOTE(review): currently unused by the body — either wire it
                into a weighted_importance ranking or drop it from the docs.

        Returns:
            Summary string (also printed to stdout).
        """
        model = self.model
        lines = [
            "=" * 60,
            "PIN Model Summary",
            "=" * 60,
            f"Features: {model.q}",
            f"Parameters: {model.count_parameters():,}",
            f"Loss: {model.loss_name}",
            f"Embedding: d={model.embedding_dim}, d'={model.hidden_dim}, d0={model.token_dim}",
            f"Shared net: d1={model.shared_dims[0]}, d2={model.shared_dims[1]}",
            f"n_pairs: {model.q * (model.q + 1) // 2} "
            f"({model.q} main + {model.q*(model.q-1)//2} interactions)",
            "",
            "Feature list:",
        ]
        for name in model.feature_names:
            spec = model.features[name]
            ftype = "continuous" if spec == "continuous" else f"categorical ({spec} levels)"
            lines.append(f"  {name}: {ftype}")

        lines.append("")
        lines.append("Output weights (|w_{jk}|, top 10):")
        weights = model.interaction_weights()
        sorted_w = sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True)[:10]
        for (fj, fk), w in sorted_w:
            tag = "(main)" if fj == fk else "(interaction)"
            lines.append(f"  {fj} x {fk}: {w:+.4f} {tag}")

        if model.train_history["val_loss"]:
            best_val = min(model.train_history["val_loss"])
            lines.append(f"\nBest val loss: {best_val:.6f}")
            lines.append(f"Epochs run: {len(model.train_history['val_loss'])}")

        lines.append("=" * 60)
        txt = "\n".join(lines)
        print(txt)
        return txt
Visualisation tools for a fitted PINModel.
Arguments:
- model: Fitted PINModel instance.
37 def interaction_heatmap( 38 self, 39 figsize: Tuple[int, int] = (8, 7), 40 cmap: str = "RdBu_r", 41 title: str = "PIN Interaction Weights |w_{jk}|", 42 ax=None, 43 ): 44 """ 45 Heatmap of interaction weight magnitudes. 46 47 Rows and columns are features. The upper triangle (j<k) shows 48 interactions; the diagonal shows main effects. 49 50 Large values indicate pairs that contribute more variance to the 51 linear predictor. Note: weight magnitude alone doesn't measure 52 importance — the range of h_{jk} also matters. Use 53 weighted_importance() for a more meaningful ranking. 54 55 Args: 56 figsize: Figure size. 57 cmap: Colormap. 58 title: Plot title. 59 ax: Existing axes to draw on (optional). 60 61 Returns: 62 (fig, ax) tuple. 63 """ 64 plt = _get_plt() 65 weights = self.model.interaction_weights() 66 feature_names = self.model.feature_names 67 q = len(feature_names) 68 69 matrix = np.zeros((q, q)) 70 for (fname_j, fname_k), w in weights.items(): 71 j = feature_names.index(fname_j) 72 k = feature_names.index(fname_k) 73 matrix[j, k] = abs(w) 74 if j != k: 75 matrix[k, j] = abs(w) 76 77 if ax is None: 78 fig, ax = plt.subplots(figsize=figsize) 79 else: 80 fig = ax.get_figure() 81 82 im = ax.imshow(matrix, cmap=cmap, aspect="auto") 83 plt.colorbar(im, ax=ax, label="|w_{jk}|") 84 85 ax.set_xticks(range(q)) 86 ax.set_yticks(range(q)) 87 ax.set_xticklabels(feature_names, rotation=45, ha="right") 88 ax.set_yticklabels(feature_names) 89 ax.set_title(title) 90 91 for i in range(q): 92 for j in range(q): 93 ax.text( 94 j, i, f"{matrix[i, j]:.3f}", 95 ha="center", va="center", fontsize=7, 96 ) 97 98 fig.tight_layout() 99 return fig, ax
Heatmap of interaction weight magnitudes.
Rows and columns are features. The upper triangle (j<k) shows interactions; the diagonal shows main effects. Large values indicate pairs that contribute more variance to the
linear predictor. Note: weight magnitude alone doesn't measure
importance — the range of h_{jk} also matters. Use
weighted_importance() for a more meaningful ranking.
Arguments:
- figsize: Figure size.
- cmap: Colormap.
- title: Plot title.
- ax: Existing axes to draw on (optional).
Returns:
(fig, ax) tuple.
101 def weighted_importance( 102 self, 103 X_background, 104 top_n: Optional[int] = None, 105 figsize: Tuple[int, int] = (8, 5), 106 ax=None, 107 ): 108 """ 109 Rank interaction pairs by the range of w_{jk} * h_{jk}(x). 110 111 The range (max - min) over the background data measures how much 112 the pair actually varies — a large weight with a flat h_{jk} adds 113 almost nothing. This gives a fairer importance metric than |w_{jk}|. 114 115 Args: 116 X_background: Background data for evaluating h_{jk}. 117 top_n: Plot only top N pairs. None = all. 118 figsize: Figure size. 119 ax: Existing axes. 120 121 Returns: 122 (fig, ax, importance_dict). 123 """ 124 plt = _get_plt() 125 contribs = self.model.pair_contributions(X_background) 126 127 importance = {} 128 for (fname_j, fname_k), vals in contribs.items(): 129 label = f"{fname_j}" if fname_j == fname_k else f"{fname_j} x {fname_k}" 130 importance[label] = float(vals.max() - vals.min()) 131 132 sorted_items = sorted(importance.items(), key=lambda x: x[1], reverse=True) 133 if top_n is not None: 134 sorted_items = sorted_items[:top_n] 135 136 labels = [it[0] for it in sorted_items] 137 values = [it[1] for it in sorted_items] 138 139 if ax is None: 140 fig, ax = plt.subplots(figsize=figsize) 141 else: 142 fig = ax.get_figure() 143 144 colors = [ 145 "#2196F3" if " x " not in label else "#FF5722" 146 for label in labels 147 ] 148 149 ax.barh(labels[::-1], values[::-1], color=colors[::-1]) 150 ax.set_xlabel("Range of w_{jk} * h_{jk}(x) on background") 151 ax.set_title("PIN Pair Importance (main effects=blue, interactions=orange)") 152 ax.grid(axis="x", alpha=0.3) 153 fig.tight_layout() 154 155 return fig, ax, dict(sorted_items)
Rank interaction pairs by the range of w_{jk} * h_{jk}(x).
The range (max - min) over the background data measures how much the pair actually varies — a large weight with a flat h_{jk} adds almost nothing. This gives a fairer importance metric than |w_{jk}|.
Arguments:
- X_background: Background data for evaluating h_{jk}.
- top_n: Plot only top N pairs. None = all.
- figsize: Figure size.
- ax: Existing axes.
Returns:
(fig, ax, importance_dict).
157 def plot_main_effect( 158 self, 159 feature: str, 160 X_background, 161 n_grid: int = 100, 162 figsize: Tuple[int, int] = (6, 4), 163 ax=None, 164 color: str = "#2196F3", 165 ): 166 """ 167 Plot main effect curve for a single feature. 168 169 Evaluates the diagonal pair term w_{jj} * h_{jj}(x_j) over a grid 170 of feature values while fixing other features to background means. 171 172 Args: 173 feature: Feature name. 174 X_background: Background data. 175 n_grid: Grid resolution. 176 figsize: Figure size. 177 ax: Existing axes. 178 color: Line colour. 179 180 Returns: 181 (fig, ax) tuple. 182 """ 183 plt = _get_plt() 184 effects = self.model.main_effects(X_background, n_grid=n_grid) 185 186 if feature not in effects: 187 raise ValueError(f"Feature '{feature}' not found. Available: {list(effects.keys())}") 188 189 grid_vals, effect_vals = effects[feature] 190 191 if ax is None: 192 fig, ax = plt.subplots(figsize=figsize) 193 else: 194 fig = ax.get_figure() 195 196 spec = self.model.features[feature] 197 if isinstance(spec, int): 198 ax.bar(grid_vals, effect_vals, color=color, alpha=0.8) 199 ax.set_xlabel(f"{feature} (category)") 200 else: 201 ax.plot(grid_vals, effect_vals, color=color, linewidth=2) 202 ax.fill_between(grid_vals, effect_vals, alpha=0.15, color=color) 203 ax.set_xlabel(feature) 204 205 ax.axhline(0, color="gray", linestyle="--", alpha=0.5) 206 ax.set_ylabel("w_{jj} * h_{jj}(x) (linear predictor scale)") 207 ax.set_title(f"Main Effect: {feature}") 208 ax.grid(alpha=0.3) 209 fig.tight_layout() 210 211 return fig, ax
Plot main effect curve for a single feature.
Evaluates the diagonal pair term w_{jj} * h_{jj}(x_j) over a grid of feature values while fixing other features to background means.
Arguments:
- feature: Feature name.
- X_background: Background data.
- n_grid: Grid resolution.
- figsize: Figure size.
- ax: Existing axes.
- color: Line colour.
Returns:
(fig, ax) tuple.
213 def plot_surface( 214 self, 215 feature_j: str, 216 feature_k: str, 217 X_background, 218 n_grid: int = 30, 219 figsize: Tuple[int, int] = (7, 5), 220 cmap: str = "RdBu_r", 221 ax=None, 222 ): 223 """ 224 Plot 2D interaction surface for a feature pair. 225 226 Evaluates w_{jk} * h_{jk}(x_j, x_k) over a grid. Points with 227 high absolute values indicate combinations where the interaction 228 adds or subtracts substantially from the linear predictor. 229 230 Args: 231 feature_j: First feature (x-axis for continuous). 232 feature_k: Second feature (y-axis or color dimension). 233 X_background: Background data for axis range estimation. 234 n_grid: Grid resolution per axis. 235 figsize: Figure size. 236 cmap: Colormap. 237 ax: Existing axes. 238 239 Returns: 240 (fig, ax) tuple. 241 """ 242 plt = _get_plt() 243 surfaces = self.model.interaction_surfaces( 244 X_background, 245 n_grid=n_grid, 246 pairs=[(feature_j, feature_k)], 247 ) 248 249 key = (feature_j, feature_k) 250 if key not in surfaces: 251 key = (feature_k, feature_j) 252 253 if key not in surfaces: 254 raise ValueError( 255 f"Pair ({feature_j}, {feature_k}) not found in surfaces. " 256 "Ensure both features exist and j < k in feature order." 
257 ) 258 259 surf_data = surfaces[key] 260 grid_j = surf_data["grid_j"] 261 grid_k = surf_data["grid_k"] 262 surface = surf_data["surface"] 263 264 if ax is None: 265 fig, ax = plt.subplots(figsize=figsize) 266 else: 267 fig = ax.get_figure() 268 269 vmax = np.abs(surface).max() 270 vmin = -vmax 271 272 spec_j = self.model.features[feature_j] 273 spec_k = self.model.features[feature_k] 274 275 if isinstance(spec_k, int): 276 for cat_idx in range(grid_k.shape[0]): 277 ax.plot( 278 grid_j, 279 surface[:, cat_idx], 280 label=f"{feature_k}={int(grid_k[cat_idx])}", 281 alpha=0.8, 282 ) 283 ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8) 284 ax.set_xlabel(feature_j) 285 ax.set_ylabel("w * h (interaction)") 286 else: 287 im = ax.pcolormesh( 288 grid_j, grid_k, surface.T, 289 cmap=cmap, vmin=vmin, vmax=vmax, 290 ) 291 plt.colorbar(im, ax=ax, label="w_{jk} * h_{jk}(x)") 292 ax.set_xlabel(feature_j) 293 ax.set_ylabel(feature_k) 294 295 ax.set_title(f"Interaction Surface: {feature_j} x {feature_k}") 296 fig.tight_layout() 297 298 return fig, ax
Plot 2D interaction surface for a feature pair.
Evaluates w_{jk} * h_{jk}(x_j, x_k) over a grid. Points with high absolute values indicate combinations where the interaction adds or subtracts substantially from the linear predictor.
Arguments:
- feature_j: First feature (x-axis for continuous).
- feature_k: Second feature (y-axis or color dimension).
- X_background: Background data for axis range estimation.
- n_grid: Grid resolution per axis.
- figsize: Figure size.
- cmap: Colormap.
- ax: Existing axes.
Returns:
(fig, ax) tuple.
300 def plot_training_history( 301 self, 302 figsize: Tuple[int, int] = (8, 4), 303 ax=None, 304 ): 305 """ 306 Plot training and validation loss curves. 307 308 Args: 309 figsize: Figure size. 310 ax: Existing axes. 311 312 Returns: 313 (fig, ax) tuple. 314 """ 315 plt = _get_plt() 316 history = self.model.train_history 317 318 if ax is None: 319 fig, ax = plt.subplots(figsize=figsize) 320 else: 321 fig = ax.get_figure() 322 323 epochs = range(len(history["train_loss"])) 324 ax.plot(epochs, history["train_loss"], label="Train", color="#2196F3") 325 ax.plot(epochs, history["val_loss"], label="Validation", color="#FF5722") 326 ax.set_xlabel("Epoch") 327 ax.set_ylabel("Loss (deviance)") 328 ax.set_title("PIN Training History") 329 ax.legend() 330 ax.grid(alpha=0.3) 331 fig.tight_layout() 332 333 return fig, ax
Plot training and validation loss curves.
Arguments:
- figsize: Figure size.
- ax: Existing axes.
Returns:
(fig, ax) tuple.
335 def summary(self, X_background=None) -> str: 336 """ 337 Print a text summary of the fitted model. 338 339 Args: 340 X_background: Optional background data for pair importance ranking. 341 342 Returns: 343 Summary string. 344 """ 345 model = self.model 346 lines = [ 347 "=" * 60, 348 "PIN Model Summary", 349 "=" * 60, 350 f"Features: {model.q}", 351 f"Parameters: {model.count_parameters():,}", 352 f"Loss: {model.loss_name}", 353 f"Embedding: d={model.embedding_dim}, d'={model.hidden_dim}, d0={model.token_dim}", 354 f"Shared net: d1={model.shared_dims[0]}, d2={model.shared_dims[1]}", 355 f"n_pairs: {model.q * (model.q + 1) // 2} " 356 f"({model.q} main + {model.q*(model.q-1)//2} interactions)", 357 "", 358 "Feature list:", 359 ] 360 for name in model.feature_names: 361 spec = model.features[name] 362 ftype = "continuous" if spec == "continuous" else f"categorical ({spec} levels)" 363 lines.append(f" {name}: {ftype}") 364 365 lines.append("") 366 lines.append("Output weights (|w_{jk}|, top 10):") 367 weights = model.interaction_weights() 368 sorted_w = sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True)[:10] 369 for (fj, fk), w in sorted_w: 370 tag = "(main)" if fj == fk else "(interaction)" 371 lines.append(f" {fj} x {fk}: {w:+.4f} {tag}") 372 373 if model.train_history["val_loss"]: 374 best_val = min(model.train_history["val_loss"]) 375 lines.append(f"\nBest val loss: {best_val:.6f}") 376 lines.append(f"Epochs run: {len(model.train_history['val_loss'])}") 377 378 lines.append("=" * 60) 379 txt = "\n".join(lines) 380 print(txt) 381 return txt
Print a text summary of the fitted model.
Arguments:
- X_background: Optional background data for pair importance ranking.
Returns:
Summary string.
def centered_hard_sigmoid(x: torch.Tensor) -> torch.Tensor:
    """
    Centered hard sigmoid used by PIN interaction units.

    Implements the paper's activation
        sigma_hard(x) = clamp((1 + x) / 2, 0, 1)
    so that x = -1 maps to 0.0, x = 0 maps to 0.5 (the "centered" sense —
    zero input gives the midpoint) and x = 1 maps to 1.0.

    This is deliberately NOT torch.nn.Hardsigmoid, which computes
    clamp((x + 3) / 6, 0, 1) — a different shift chosen to match
    sigmoid(-3) = 0.5 behaviour; PIN uses the simpler (1 + x) / 2 form.

    Args:
        x: Input tensor, any shape.

    Returns:
        Tensor of the same shape, with values in [0, 1].
    """
    # (x + 1) * 0.5 is bit-identical to (1 + x) / 2 in IEEE arithmetic.
    half_shifted = (x + 1.0) * 0.5
    return half_shifted.clamp(min=0.0, max=1.0)
The activation used in PIN interaction units.
Defined in the paper as:
sigma_hard(x) = clamp((1 + x) / 2, 0, 1)
This maps:
x = 0 -> 0.5 (the 'centered' sense — zero input gives midpoint)
x = -1 -> 0.0
x = 1 -> 1.0
Note: NOT torch.nn.Hardsigmoid, which uses clamp((x+3)/6, 0, 1) — a different shift chosen to match sigmoid(-3)=0.5 behaviour. The PIN paper uses the simpler (1+x)/2 formulation.
Arguments:
- x: Input tensor, any shape.
Returns:
Tensor of same shape, values in [0, 1].