Edit on GitHub

insurance_gam.pin

insurance_gam.pin — Pairwise Interaction Networks subpackage.

Re-exports the full public API of the original insurance-pin package.

Requires the ``pin`` extra; install with:

pip install insurance-gam[pin]
 1"""
 2insurance_gam.pin — Pairwise Interaction Networks subpackage.
 3
 4Re-exports the full public API of the original insurance-pin package.
 5
 6Requires the ``pin`` extra::
 7
 8    pip install insurance-gam[pin]
 9"""
10
11try:
12    from .model import PINModel, PINEnsemble
13    from .diagnostics import PINDiagnostics
14    from .networks import centered_hard_sigmoid
15except ImportError as e:
16    raise ImportError(
17        "insurance_gam.pin requires the pin extra. "
18        "Install with: pip install insurance-gam[pin]\n"
19        f"Original error: {e}"
20    ) from e
21
22__all__ = [
23    "PINModel",
24    "PINEnsemble",
25    "PINDiagnostics",
26    "centered_hard_sigmoid",
27]
class PINModel(nn.Module):
    """
    Single Tree-like Pairwise Interaction Network.

    Prediction:
        f_PIN(x) = exp( sum_{j<=k} w_{jk} * h_{jk}(x) + b )

    Args:
        features: Dict mapping feature name to spec. Spec is 'continuous' or
            an int (number of categories). Order matters — features are indexed
            by position.
        embedding_dim: Feature embedding dimension d (default 10).
        hidden_dim: Hidden width for continuous embedding FNNs d' (default 20).
        token_dim: Interaction token dimension d0 (default 10).
        shared_dims: (d1, d2) widths for shared interaction network (default (30, 20)).
        loss: Loss name — 'poisson', 'gamma', or 'tweedie'.
        tweedie_p: Tweedie power (only used when loss='tweedie').
        lr: Adam learning rate (default 0.001).
        batch_size: Mini-batch size (default 128).
        max_epochs: Maximum training epochs (default 500).
        patience: Early stopping patience in epochs (default 20).
        lr_patience: ReduceLROnPlateau patience (default 5).
        lr_factor: ReduceLROnPlateau reduction factor (default 0.9).
        val_fraction: Fraction of training data for validation if X_val not given
            (default 0.1).
        device: Torch device string, or None to auto-detect.
        random_seed: Seed for reproducibility.

    Examples:
        >>> model = PINModel(
        ...     features={"age": "continuous", "area": 5},
        ...     loss="poisson",
        ... )
    """

    def __init__(
        self,
        features: FeatureSpec,
        embedding_dim: int = 10,
        hidden_dim: int = 20,
        token_dim: int = 10,
        shared_dims: Tuple[int, int] = (30, 20),
        loss: str = "poisson",
        tweedie_p: float = 1.5,
        lr: float = 0.001,
        batch_size: int = 128,
        max_epochs: int = 500,
        patience: int = 20,
        lr_patience: int = 5,
        lr_factor: float = 0.9,
        val_fraction: float = 0.1,
        device: Optional[str] = None,
        random_seed: int = 42,
    ) -> None:
        super().__init__()

        # Copy the spec so later caller-side mutation cannot change the model.
        self.features = dict(features)
        # Feature order defines pair indexing everywhere below.
        self.feature_names: List[str] = list(features.keys())
        self.q = len(self.feature_names)  # number of features
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.token_dim = token_dim
        self.shared_dims = tuple(shared_dims)
        self.loss_name = loss
        self.tweedie_p = tweedie_p
        self.lr = lr
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.patience = patience
        self.lr_patience = lr_patience
        self.lr_factor = lr_factor
        self.val_fraction = val_fraction
        self.random_seed = random_seed

        # Resolve device: explicit string wins, otherwise prefer CUDA if present.
        if device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self._device = torch.device(device)

        # --- Sub-modules ---
        self.feature_embeddings = FeatureEmbeddings(
            features=features,
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
        )
        self.interaction_tokens = InteractionTokens(
            n_features=self.q,
            token_dim=token_dim,
        )
        self.shared_net = SharedInteractionNet(
            embedding_dim=embedding_dim,
            token_dim=token_dim,
            layer1_dim=shared_dims[0],
            layer2_dim=shared_dims[1],
        )

        # All unordered pairs including diagonal (j == k): q(q+1)/2 terms.
        n_pairs = self.q * (self.q + 1) // 2
        # Output weights w_{jk} and bias b — linear combination of pair terms
        self.output_weights = nn.Parameter(torch.randn(n_pairs) * 0.01)
        self.output_bias = nn.Parameter(torch.zeros(1))

        # Loss function — only the Tweedie loss takes an extra power parameter.
        loss_kwargs = {"p": tweedie_p} if loss == "tweedie" else {}
        self._loss_fn = get_loss(loss, **loss_kwargs)

        # Weight initialisation: smaller scales for numerical stability
        self._init_weights()

        # Training state
        self._is_fitted = False
        self.train_history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}

        # Centering offsets (set post-hoc after fitting)
        # h_{jk}^centered(x) = h_{jk}(x) - mean_train[h_{jk}]
        # We store mean_train[w_{jk} * h_{jk}] per pair for efficiency
        self._pair_means: Optional[torch.Tensor] = None
178    # ------------------------------------------------------------------
179    # Weight initialisation
180    # ------------------------------------------------------------------
181
182    def _init_weights(self) -> None:
183        """
184        Initialise all linear layers to smaller scales for training stability.
185
186        Default PyTorch Kaiming init can produce large initial activations
187        when many layers are composed. With smaller scales, early training
188        is more stable even on very small datasets.
189        """
190        for module in self.modules():
191            if isinstance(module, nn.Linear):
192                nn.init.normal_(module.weight, mean=0.0, std=0.01)
193                if module.bias is not None:
194                    nn.init.zeros_(module.bias)
195            elif isinstance(module, nn.Embedding):
196                nn.init.normal_(module.weight, mean=0.0, std=0.01)
197
198    # ------------------------------------------------------------------
199    # Forward pass
200    # ------------------------------------------------------------------
201
202    def _compute_linear_predictor(
203        self,
204        x_dict: Dict[str, torch.Tensor],
205        apply_centering: bool = True,
206    ) -> torch.Tensor:
207        """
208        Compute sum_{j<=k} w_{jk} * h_{jk}(x) + b.
209
210        Args:
211            x_dict: Feature tensors on self._device.
212            apply_centering: Subtract pair means (set during fit) for identifiability.
213
214        Returns:
215            Shape (batch_size,).
216        """
217        # Embed all features once
218        embeddings: Dict[str, torch.Tensor] = self.feature_embeddings.embed_all(x_dict)
219
220        pairs = self.interaction_tokens.pair_indices()
221        terms = []
222
223        for pair_idx, (j, k) in enumerate(pairs):
224            fname_j = self.feature_names[j]
225            fname_k = self.feature_names[k]
226
227            phi_j = embeddings[fname_j]
228            phi_k = embeddings[fname_k]
229            token = self.interaction_tokens.get_token(j, k)
230
231            raw = self.shared_net(phi_j, phi_k, token)  # (batch, 1)
232            h = centered_hard_sigmoid(raw).squeeze(-1)   # (batch,)
233
234            w = self.output_weights[pair_idx]
235            term = w * h  # (batch,)
236            terms.append(term)
237
238        # Stack to (batch, n_pairs) then sum over pairs
239        all_terms = torch.stack(terms, dim=1)  # (batch, n_pairs)
240
241        if apply_centering and self._pair_means is not None:
242            all_terms = all_terms - self._pair_means.unsqueeze(0)
243
244        linear_pred = all_terms.sum(dim=1) + self.output_bias.squeeze()
245        return linear_pred
246
247    def forward(
248        self,
249        x_dict: Dict[str, torch.Tensor],
250        exposure: Optional[torch.Tensor] = None,
251    ) -> torch.Tensor:
252        """
253        Compute predictions f_PIN(x) = exp(linear_predictor) * exposure.
254
255        When exposure is provided, the raw model output is frequency (claims per
256        year) and multiplying by exposure gives expected claim count. For
257        frequency models, typically you'd pass exposure=None and let the caller
258        multiply; this method supports both modes.
259
260        Args:
261            x_dict: Feature tensors. Dict of feature_name -> (batch,) tensor.
262            exposure: Optional per-sample exposure, shape (batch,).
263
264        Returns:
265            Predicted frequency, shape (batch,).
266        """
267        eta = self._compute_linear_predictor(x_dict)
268        # Clamp to avoid exp overflow (GLM link stabilisation)
269        eta = torch.clamp(eta, min=-20.0, max=20.0)
270        mu = torch.exp(eta)
271        if exposure is not None:
272            mu = mu * exposure
273        return mu
274
275    # ------------------------------------------------------------------
276    # Data preparation
277    # ------------------------------------------------------------------
278
279    def _prepare_features(
280        self,
281        X: Union[Dict, "pl.DataFrame", "pd.DataFrame"],  # noqa: F821
282    ) -> Dict[str, torch.Tensor]:
283        """
284        Convert input data to a dict of tensors.
285
286        Accepts:
287        - Dict[str, np.ndarray]
288        - Dict[str, list]
289        - polars.DataFrame
290        - pandas.DataFrame
291        """
292        # Try polars first
293        try:
294            import polars as pl
295            if isinstance(X, pl.DataFrame):
296                return self._polars_to_dict(X)
297        except ImportError:
298            pass
299
300        # Try pandas
301        try:
302            import pandas as pd
303            if isinstance(X, pd.DataFrame):
304                return {col: X[col].to_numpy() for col in self.feature_names}
305        except ImportError:
306            pass
307
308        # Assume it's already a dict
309        if isinstance(X, dict):
310            return X
311
312        raise TypeError(
313            f"X must be a dict, polars.DataFrame, or pandas.DataFrame. Got {type(X)}."
314        )
315
316    def _polars_to_dict(self, df) -> Dict[str, np.ndarray]:
317        result = {}
318        for name in self.feature_names:
319            result[name] = df[name].to_numpy()
320        return result
321
322    def _to_device_dict(
323        self, x_dict: Dict[str, Union[np.ndarray, torch.Tensor]]
324    ) -> Dict[str, torch.Tensor]:
325        """Convert dict of arrays to tensors on device."""
326        result = {}
327        for name, arr in x_dict.items():
328            spec = self.features[name]
329            if spec == "continuous":
330                result[name] = _to_tensor(arr).to(self._device)
331            else:
332                result[name] = _to_long_tensor(arr).to(self._device)
333        return result
334
335    # ------------------------------------------------------------------
336    # Training
337    # ------------------------------------------------------------------
338
339    def fit(
340        self,
341        X_train,
342        y_train: np.ndarray,
343        exposure: Optional[np.ndarray] = None,
344        X_val=None,
345        y_val: Optional[np.ndarray] = None,
346        exposure_val: Optional[np.ndarray] = None,
347        verbose: bool = True,
348    ) -> "PINModel":
349        """
350        Fit the model.
351
352        Args:
353            X_train: Training features. Dict, polars.DataFrame, or pandas.DataFrame.
354            y_train: Observed frequency (claims / exposure), shape (n,).
355            exposure: Per-sample exposure (years at risk), shape (n,).
356            X_val: Validation features (optional). If None, 10% of training data
357                is reserved.
358            y_val: Validation targets.
359            exposure_val: Validation exposure.
360            verbose: Print training progress.
361
362        Returns:
363            self (for chaining).
364        """
365        torch.manual_seed(self.random_seed)
366        np.random.seed(self.random_seed)
367
368        self.to(self._device)
369
370        # Prepare training data
371        x_dict = self._prepare_features(X_train)
372        x_dict = self._to_device_dict(x_dict)
373        y_t = _to_tensor(y_train).to(self._device)
374        exp_t = (
375            _to_tensor(exposure).to(self._device)
376            if exposure is not None
377            else torch.ones_like(y_t)
378        )
379
380        n = y_t.shape[0]
381
382        # Build or reserve validation set (must happen before bias init so we
383        # use training-only data for the mean frequency estimate).
384        if X_val is not None:
385            x_val_dict = self._prepare_features(X_val)
386            x_val_dict = self._to_device_dict(x_val_dict)
387            y_val_t = _to_tensor(y_val).to(self._device)
388            exp_val_t = (
389                _to_tensor(exposure_val).to(self._device)
390                if exposure_val is not None
391                else torch.ones_like(y_val_t)
392            )
393        else:
394            # Reserve val_fraction from training
395            val_size = max(1, int(n * self.val_fraction))
396            perm = torch.randperm(n, device=self._device)
397            val_idx = perm[:val_size]
398            train_idx = perm[val_size:]
399
400            x_val_dict = {k: v[val_idx] for k, v in x_dict.items()}
401            y_val_t = y_t[val_idx]
402            exp_val_t = exp_t[val_idx]
403
404            x_dict = {k: v[train_idx] for k, v in x_dict.items()}
405            y_t = y_t[train_idx]
406            exp_t = exp_t[train_idx]
407
408        n_train = y_t.shape[0]
409
410        # Initialise bias to log(mean_frequency) of the TRAINING split only.
411        # Computing this from the full dataset (including validation) would
412        # be a mild form of data leakage — small in practice but wrong in
413        # principle. We compute it here, after the split.
414        with torch.no_grad():
415            mean_freq = y_t.mean().clamp(min=1e-8)
416            self.output_bias.fill_(torch.log(mean_freq).item())
417
418        optimizer = optim.Adam(self.parameters(), lr=self.lr)
419        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
420            optimizer,
421            mode="min",
422            factor=self.lr_factor,
423            patience=self.lr_patience,
424        )
425
426        best_val_loss = float("inf")
427        best_state = None
428        epochs_no_improve = 0
429
430        self.train_history = {"train_loss": [], "val_loss": []}
431
432        for epoch in range(self.max_epochs):
433            self.train()
434            # Shuffle training data
435            perm = torch.randperm(n_train, device=self._device)
436
437            epoch_loss = 0.0
438            n_batches = 0
439
440            for start in range(0, n_train, self.batch_size):
441                end = min(start + self.batch_size, n_train)
442                batch_idx = perm[start:end]
443
444                x_batch = {k: v[batch_idx] for k, v in x_dict.items()}
445                y_batch = y_t[batch_idx]
446                exp_batch = exp_t[batch_idx]
447
448                optimizer.zero_grad()
449                mu = self.forward(x_batch)
450                loss = self._loss_fn(mu, y_batch, exp_batch)
451                if torch.isnan(loss):
452                    continue  # skip NaN batches
453                loss.backward()
454                # Gradient clipping prevents exploding gradients in early training
455                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
456                # Zero NaN gradients before step
457                for p in self.parameters():
458                    if p.grad is not None and torch.isnan(p.grad).any():
459                        p.grad.zero_()
460                optimizer.step()
461                # Restore any NaN params to their pre-step values
462                for p in self.parameters():
463                    if torch.isnan(p).any():
464                        torch.nan_to_num_(p, nan=0.0)
465
466                epoch_loss += loss.item()
467                n_batches += 1
468
469            train_loss = epoch_loss / n_batches
470
471            # Validation
472            self.eval()
473            with torch.no_grad():
474                mu_val = self.forward(x_val_dict)
475                val_loss = self._loss_fn(mu_val, y_val_t, exp_val_t).item()
476
477            scheduler.step(val_loss)
478
479            self.train_history["train_loss"].append(train_loss)
480            self.train_history["val_loss"].append(val_loss)
481
482            if verbose and (epoch % 50 == 0 or epoch < 10):
483                lr_now = optimizer.param_groups[0]["lr"]
484                print(
485                    f"Epoch {epoch:4d} | train={train_loss:.6f} | val={val_loss:.6f} | lr={lr_now:.2e}"
486                )
487
488            # Early stopping
489            if val_loss < best_val_loss:
490                best_val_loss = val_loss
491                best_state = copy.deepcopy(self.state_dict())
492                epochs_no_improve = 0
493            else:
494                epochs_no_improve += 1
495                if epochs_no_improve >= self.patience:
496                    if verbose:
497                        print(f"Early stopping at epoch {epoch}. Best val={best_val_loss:.6f}")
498                    break
499
500        # Restore best weights
501        if best_state is not None:
502            self.load_state_dict(best_state)
503
504        # Post-hoc centering: compute pair means on training data
505        self._compute_pair_centering(x_dict)
506
507        self._is_fitted = True
508        return self
509
    def _compute_pair_centering(
        self, x_dict: Dict[str, torch.Tensor], batch_size: int = 1024
    ) -> None:
        """
        Compute mean of w_{jk} * h_{jk}(x) over training data for each pair.

        Centering ensures that the bias b absorbs the intercept and the
        interaction terms have zero mean — making w_{jk} interpretable as
        the amplitude of the interaction effect.

        Args:
            x_dict: Training feature tensors on self._device (all the same length).
            batch_size: Chunk size for the streaming mean (memory, not math).

        Stores result in self._pair_means, shape (n_pairs,).
        """
        self.eval()
        # All feature tensors share the same first dimension; use the first one.
        n = x_dict[self.feature_names[0]].shape[0]
        n_pairs = self.q * (self.q + 1) // 2
        pairs = self.interaction_tokens.pair_indices()

        # Streaming sums per pair; divided by the total count at the end.
        pair_sums = torch.zeros(n_pairs, device=self._device)
        n_seen = 0

        with torch.no_grad():
            for start in range(0, n, batch_size):
                end = min(start + batch_size, n)
                x_batch = {k: v[start:end] for k, v in x_dict.items()}
                # One embedding pass per batch, reused across all pairs.
                embeddings = self.feature_embeddings.embed_all(x_batch)
                batch_n = end - start

                for pair_idx, (j, k) in enumerate(pairs):
                    fname_j = self.feature_names[j]
                    fname_k = self.feature_names[k]
                    phi_j = embeddings[fname_j]
                    phi_k = embeddings[fname_k]
                    token = self.interaction_tokens.get_token(j, k)
                    raw = self.shared_net(phi_j, phi_k, token)
                    h = centered_hard_sigmoid(raw).squeeze(-1)
                    w = self.output_weights[pair_idx]
                    pair_sums[pair_idx] += (w * h).sum()

                n_seen += batch_n

        self._pair_means = pair_sums / n_seen
551
552    # ------------------------------------------------------------------
553    # Prediction
554    # ------------------------------------------------------------------
555
556    def predict(
557        self,
558        X,
559        exposure: Optional[np.ndarray] = None,
560    ) -> np.ndarray:
561        """
562        Predict frequency (or expected claims if exposure given).
563
564        Args:
565            X: Features. Dict, polars.DataFrame, or pandas.DataFrame.
566            exposure: Per-sample exposure.
567
568        Returns:
569            Predictions as numpy array, shape (n,).
570        """
571        if not self._is_fitted:
572            raise RuntimeError("Call fit() before predict().")
573
574        self.eval()
575        x_dict = self._prepare_features(X)
576        x_dict = self._to_device_dict(x_dict)
577
578        exp_t = (
579            _to_tensor(exposure).to(self._device)
580            if exposure is not None
581            else None
582        )
583
584        with torch.no_grad():
585            mu = self.forward(x_dict, exposure=exp_t)
586
587        return mu.cpu().numpy()
588
589    # ------------------------------------------------------------------
590    # Interpretability surfaces
591    # ------------------------------------------------------------------
592
593    def pair_contributions(
594        self,
595        X,
596    ) -> Dict[Tuple[str, str], np.ndarray]:
597        """
598        Compute w_{jk} * h_{jk}(x) for each pair and each sample.
599
600        These are the additive components on the linear predictor scale.
601        Useful for understanding which pairs drive predictions.
602
603        Args:
604            X: Features.
605
606        Returns:
607            Dict mapping (fname_j, fname_k) -> array of shape (n,).
608        """
609        if not self._is_fitted:
610            raise RuntimeError("Call fit() before pair_contributions().")
611
612        self.eval()
613        x_dict = self._prepare_features(X)
614        x_dict = self._to_device_dict(x_dict)
615        pairs = self.interaction_tokens.pair_indices()
616        result = {}
617
618        with torch.no_grad():
619            embeddings = self.feature_embeddings.embed_all(x_dict)
620
621            for pair_idx, (j, k) in enumerate(pairs):
622                fname_j = self.feature_names[j]
623                fname_k = self.feature_names[k]
624                phi_j = embeddings[fname_j]
625                phi_k = embeddings[fname_k]
626                token = self.interaction_tokens.get_token(j, k)
627                raw = self.shared_net(phi_j, phi_k, token)
628                h = centered_hard_sigmoid(raw).squeeze(-1)
629                w = self.output_weights[pair_idx]
630                contrib = (w * h)
631                if self._pair_means is not None:
632                    contrib = contrib - self._pair_means[pair_idx]
633                result[(fname_j, fname_k)] = contrib.cpu().numpy()
634
635        return result
636
    def main_effects(
        self,
        X_background,
        n_grid: int = 100,
    ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """
        Compute main effect curves for each feature.

        For continuous features: evaluates the diagonal term h_{jj} over an
        evenly-spaced grid spanning the background data range. For categorical
        features: evaluates h_{jj} at every category code.

        Note: the effect is the diagonal pair contribution w_{jj} * h_{jj}
        WITHOUT the centering offset applied (compare pair_contributions()).

        Args:
            X_background: Background data used for grid ranges and for fixing
                the other features (mean for continuous, mode for categorical).
            n_grid: Number of grid points for continuous features.

        Returns:
            Dict mapping feature_name -> (grid_values, effect_values).
            For categoricals: grid_values are integer category codes.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before main_effects().")

        self.eval()
        bg_dict = self._prepare_features(X_background)
        bg_dict = self._to_device_dict(bg_dict)

        result = {}

        with torch.no_grad():
            for i, fname in enumerate(self.feature_names):
                spec = self.features[fname]

                if spec == "continuous":
                    # Evenly-spaced grid over the observed background range.
                    vals = bg_dict[fname].float()
                    lo, hi = vals.min().item(), vals.max().item()
                    grid = torch.linspace(lo, hi, n_grid, device=self._device)
                elif isinstance(spec, int):
                    # One grid point per category code 0..spec-1.
                    grid = torch.arange(spec, device=self._device)
                else:
                    # Unknown spec type: skip rather than fail.
                    continue

                # Build batch: vary feature i, fix others to background mean
                n_grid_pts = grid.shape[0]
                x_eval = {}
                for fname2 in self.feature_names:
                    spec2 = self.features[fname2]
                    if fname2 == fname:
                        x_eval[fname2] = grid.long() if isinstance(spec, int) else grid
                    else:
                        if spec2 == "continuous":
                            mean_val = bg_dict[fname2].float().mean()
                            # 0-dim tensor expanded to the grid length.
                            x_eval[fname2] = mean_val.expand(n_grid_pts)
                        else:
                            # torch.mode on a 1-D tensor yields a 0-dim values
                            # tensor — the most frequent category code.
                            mode_val = bg_dict[fname2].long().mode().values
                            x_eval[fname2] = mode_val.expand(n_grid_pts)

                # Get diagonal pair contribution (j=i, k=i)
                phi_j = self.feature_embeddings.embed_feature(fname, x_eval[fname])
                token = self.interaction_tokens.get_token(i, i)
                # Diagonal term uses the same embedding for both slots.
                raw = self.shared_net(phi_j, phi_j, token)
                h = centered_hard_sigmoid(raw).squeeze(-1)
                pair_idx = self.interaction_tokens._pair_to_idx[(i, i)]
                w = self.output_weights[pair_idx]
                effect = (w * h).cpu().numpy()

                result[fname] = (grid.cpu().numpy(), effect)

        return result
701
    def interaction_surfaces(
        self,
        X_background,
        n_grid: int = 30,
        pairs: Optional[List[Tuple[str, str]]] = None,
    ) -> Dict[Tuple[str, str], Dict]:
        """
        Compute 2D interaction surfaces for feature pairs.

        For pair (j, k): evaluates w_{jk} * h_{jk}(x_j, x_k) over a 2D grid.
        Main effect contributions (j=j and k=k pairs) are NOT subtracted here —
        this shows the raw interaction term. Use pair_contributions() for
        full decomposition.

        Args:
            X_background: Background data for range estimation.
            n_grid: Grid resolution per axis (n_grid x n_grid for continuous).
            pairs: List of (feature_j, feature_k) pairs to compute. If None,
                computes all off-diagonal interaction pairs.

        Returns:
            Dict mapping (fname_j, fname_k) -> {
                'grid_j': array of shape (n_grid,),
                'grid_k': array of shape (n_grid,) or (n_cats,),
                'surface': array of shape (n_grid_j, n_grid_k),
            }

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before interaction_surfaces().")

        self.eval()
        bg_dict = self._prepare_features(X_background)
        bg_dict = self._to_device_dict(bg_dict)

        # All off-diagonal pairs if not specified
        if pairs is None:
            all_pairs = self.interaction_tokens.pair_indices()
            pairs = [
                (self.feature_names[j], self.feature_names[k])
                for j, k in all_pairs
                if j != k
            ]

        result = {}

        with torch.no_grad():
            for (fname_j, fname_k) in pairs:
                j = self.feature_names.index(fname_j)
                k = self.feature_names.index(fname_k)

                spec_j = self.features[fname_j]
                spec_k = self.features[fname_k]

                # Per-axis grid: linspace over the background range for
                # continuous features, all category codes for categoricals.
                if spec_j == "continuous":
                    vals_j = bg_dict[fname_j].float()
                    grid_j = torch.linspace(vals_j.min(), vals_j.max(), n_grid, device=self._device)
                else:
                    grid_j = torch.arange(spec_j, device=self._device)

                if spec_k == "continuous":
                    vals_k = bg_dict[fname_k].float()
                    grid_k = torch.linspace(vals_k.min(), vals_k.max(), n_grid, device=self._device)
                else:
                    grid_k = torch.arange(spec_k, device=self._device)

                nj, nk = grid_j.shape[0], grid_k.shape[0]

                # Meshgrid: repeat_interleave varies j slowest and repeat
                # varies k fastest, which matches the reshape(nj, nk) below.
                gj = grid_j.repeat_interleave(nk)  # (nj * nk,)
                gk = grid_k.repeat(nj)              # (nj * nk,)

                phi_j = self.feature_embeddings.embed_feature(
                    fname_j, gj.long() if isinstance(spec_j, int) else gj
                )
                phi_k = self.feature_embeddings.embed_feature(
                    fname_k, gk.long() if isinstance(spec_k, int) else gk
                )

                token = self.interaction_tokens.get_token(j, k)
                raw = self.shared_net(phi_j, phi_k, token)
                h = centered_hard_sigmoid(raw).squeeze(-1)
                pair_idx = self.interaction_tokens._pair_to_idx[(j, k)]
                w = self.output_weights[pair_idx]
                # Row index = grid_j position, column index = grid_k position.
                surface = (w * h).reshape(nj, nk).cpu().numpy()

                result[(fname_j, fname_k)] = {
                    "grid_j": grid_j.cpu().numpy(),
                    "grid_k": grid_k.cpu().numpy(),
                    "surface": surface,
                }

        return result
794
795    def shapley_values(
796        self,
797        X,
798        X_background,
799        n_background: int = 100,
800    ) -> Dict[str, np.ndarray]:
801        """
802        Compute exact Shapley values using the pairwise additive structure.
803
804        Cost: 2(q+1) forward passes per test sample per background sample.
805        For q=9, n_background=100: ~2000 forward passes per test sample.
806        This is exact — no sampling approximation.
807
808        Args:
809            X: Test data.
810            X_background: Background data for baseline distribution.
811            n_background: Number of background samples to use.
812
813        Returns:
814            Dict mapping feature_name -> (n_test,) array of SHAP values.
815            Values are on the linear predictor scale.
816        """
817        from insurance_gam.pin.shapley import exact_shapley_values
818
819        if not self._is_fitted:
820            raise RuntimeError("Call fit() before shapley_values().")
821
822        x_dict = self._prepare_features(X)
823        x_dict = self._to_device_dict(x_dict)
824        bg_dict = self._prepare_features(X_background)
825        bg_dict = self._to_device_dict(bg_dict)
826
827        return exact_shapley_values(self, x_dict, bg_dict, n_background)
828
def interaction_weights(self) -> Dict[Tuple[str, str], float]:
    """
    Return output weights w_{jk} for all pairs as a dict.

    Large absolute weights indicate important interactions or main effects.
    Note: weights are on the scale of the linear predictor; compare
    w_{jk} * h_{jk} range rather than w_{jk} alone for fair comparison.

    Returns:
        Dict mapping (fname_j, fname_k) -> float.
    """
    names = self.feature_names
    raw = self.output_weights.detach().cpu().numpy()
    out: Dict[Tuple[str, str], float] = {}
    for idx, (j, k) in enumerate(self.interaction_tokens.pair_indices()):
        out[(names[j], names[k])] = float(raw[idx])
    return out
846
def count_parameters(self) -> int:
    """Return the total number of trainable parameters in the model."""
    total = 0
    for param in self.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

Single Tree-like Pairwise Interaction Network.

Prediction:

f_PIN(x) = exp( sum_{j<=k} w_{jk} * h_{jk}(x) + b )

Arguments:
  • features: Dict mapping feature name to spec. Spec is 'continuous' or an int (number of categories). Order matters — features are indexed by position.
  • embedding_dim: Feature embedding dimension d (default 10).
  • hidden_dim: Hidden width for continuous embedding FNNs d' (default 20).
  • token_dim: Interaction token dimension d0 (default 10).
  • shared_dims: (d1, d2) widths for shared interaction network (default [30, 20]).
  • loss: Loss name — 'poisson', 'gamma', or 'tweedie'.
  • tweedie_p: Tweedie power (only used when loss='tweedie').
  • lr: Adam learning rate (default 0.001).
  • batch_size: Mini-batch size (default 128).
  • max_epochs: Maximum training epochs (default 500).
  • patience: Early stopping patience in epochs (default 20).
  • lr_patience: ReduceLROnPlateau patience (default 5).
  • lr_factor: ReduceLROnPlateau reduction factor (default 0.9).
  • val_fraction: Fraction of training data for validation if X_val not given (default 0.1).
  • device: Torch device string, or None to auto-detect.
  • random_seed: Seed for reproducibility.
Examples:
>>> model = PINModel(
...     features={"age": "continuous", "area": 5},
...     loss="poisson",
... )
PINModel( features: Dict[str, Union[str, int]], embedding_dim: int = 10, hidden_dim: int = 20, token_dim: int = 10, shared_dims: Tuple[int, int] = (30, 20), loss: str = 'poisson', tweedie_p: float = 1.5, lr: float = 0.001, batch_size: int = 128, max_epochs: int = 500, patience: int = 20, lr_patience: int = 5, lr_factor: float = 0.9, val_fraction: float = 0.1, device: Optional[str] = None, random_seed: int = 42)
 95    def __init__(
 96        self,
 97        features: FeatureSpec,
 98        embedding_dim: int = 10,
 99        hidden_dim: int = 20,
100        token_dim: int = 10,
101        shared_dims: Tuple[int, int] = (30, 20),
102        loss: str = "poisson",
103        tweedie_p: float = 1.5,
104        lr: float = 0.001,
105        batch_size: int = 128,
106        max_epochs: int = 500,
107        patience: int = 20,
108        lr_patience: int = 5,
109        lr_factor: float = 0.9,
110        val_fraction: float = 0.1,
111        device: Optional[str] = None,
112        random_seed: int = 42,
113    ) -> None:
114        super().__init__()
115
116        self.features = dict(features)
117        self.feature_names: List[str] = list(features.keys())
118        self.q = len(self.feature_names)
119        self.embedding_dim = embedding_dim
120        self.hidden_dim = hidden_dim
121        self.token_dim = token_dim
122        self.shared_dims = tuple(shared_dims)
123        self.loss_name = loss
124        self.tweedie_p = tweedie_p
125        self.lr = lr
126        self.batch_size = batch_size
127        self.max_epochs = max_epochs
128        self.patience = patience
129        self.lr_patience = lr_patience
130        self.lr_factor = lr_factor
131        self.val_fraction = val_fraction
132        self.random_seed = random_seed
133
134        # Resolve device
135        if device is None:
136            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137        else:
138            self._device = torch.device(device)
139
140        # --- Sub-modules ---
141        self.feature_embeddings = FeatureEmbeddings(
142            features=features,
143            embedding_dim=embedding_dim,
144            hidden_dim=hidden_dim,
145        )
146        self.interaction_tokens = InteractionTokens(
147            n_features=self.q,
148            token_dim=token_dim,
149        )
150        self.shared_net = SharedInteractionNet(
151            embedding_dim=embedding_dim,
152            token_dim=token_dim,
153            layer1_dim=shared_dims[0],
154            layer2_dim=shared_dims[1],
155        )
156
157        n_pairs = self.q * (self.q + 1) // 2
158        # Output weights w_{jk} and bias b — linear combination of pair terms
159        self.output_weights = nn.Parameter(torch.randn(n_pairs) * 0.01)
160        self.output_bias = nn.Parameter(torch.zeros(1))
161
162        # Loss function
163        loss_kwargs = {"p": tweedie_p} if loss == "tweedie" else {}
164        self._loss_fn = get_loss(loss, **loss_kwargs)
165
166        # Weight initialisation: smaller scales for numerical stability
167        self._init_weights()
168
169        # Training state
170        self._is_fitted = False
171        self.train_history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
172
173        # Centering offsets (set post-hoc after fitting)
174        # h_{jk}^centered(x) = h_{jk}(x) - mean_train[h_{jk}]
175        # We store mean_train[w_{jk} * h_{jk}] per pair for efficiency
176        self._pair_means: Optional[torch.Tensor] = None

Construct the PIN model: build the feature embeddings, interaction tokens, shared interaction network and output weights, resolve the compute device, and select the loss function.

features
feature_names: List[str]
q
embedding_dim
hidden_dim
token_dim
shared_dims
loss_name
tweedie_p
lr
batch_size
max_epochs
patience
lr_patience
lr_factor
val_fraction
random_seed
feature_embeddings
interaction_tokens
shared_net
output_weights
output_bias
train_history: Dict[str, List[float]]
def forward( self, x_dict: Dict[str, torch.Tensor], exposure: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
    self,
    x_dict: Dict[str, torch.Tensor],
    exposure: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute predictions f_PIN(x) = exp(linear_predictor) * exposure.

    When exposure is provided, the raw model output is frequency (claims per
    year) and multiplying by exposure gives expected claim count. For
    frequency models, typically you'd pass exposure=None and let the caller
    multiply; this method supports both modes.

    Args:
        x_dict: Feature tensors. Dict of feature_name -> (batch,) tensor.
        exposure: Optional per-sample exposure, shape (batch,).

    Returns:
        Predicted frequency, shape (batch,).
    """
    # Stabilised log-link: clamp the linear predictor so exp() cannot
    # overflow, then map to the mean scale.
    eta = self._compute_linear_predictor(x_dict).clamp(min=-20.0, max=20.0)
    mu = torch.exp(eta)
    return mu if exposure is None else mu * exposure

Compute predictions f_PIN(x) = exp(linear_predictor) * exposure.

When exposure is provided, the raw model output is frequency (claims per year) and multiplying by exposure gives expected claim count. For frequency models, typically you'd pass exposure=None and let the caller multiply; this method supports both modes.

Arguments:
  • x_dict: Feature tensors. Dict of feature_name -> (batch,) tensor.
  • exposure: Optional per-sample exposure, shape (batch,).
Returns:

Predicted frequency, shape (batch,).

def fit( self, X_train, y_train: numpy.ndarray, exposure: Optional[numpy.ndarray] = None, X_val=None, y_val: Optional[numpy.ndarray] = None, exposure_val: Optional[numpy.ndarray] = None, verbose: bool = True) -> PINModel:
def fit(
    self,
    X_train,
    y_train: np.ndarray,
    exposure: Optional[np.ndarray] = None,
    X_val=None,
    y_val: Optional[np.ndarray] = None,
    exposure_val: Optional[np.ndarray] = None,
    verbose: bool = True,
) -> "PINModel":
    """
    Fit the model with Adam, LR-on-plateau scheduling and early stopping.

    Args:
        X_train: Training features. Dict, polars.DataFrame, or pandas.DataFrame.
        y_train: Observed frequency (claims / exposure), shape (n,).
        exposure: Per-sample exposure (years at risk), shape (n,).
        X_val: Validation features (optional). If None, ``val_fraction`` of the
            training data is reserved.
        y_val: Validation targets.
        exposure_val: Validation exposure.
        verbose: Print training progress.

    Returns:
        self (for chaining).
    """
    torch.manual_seed(self.random_seed)
    np.random.seed(self.random_seed)

    self.to(self._device)

    # Prepare training data
    x_dict = self._prepare_features(X_train)
    x_dict = self._to_device_dict(x_dict)
    y_t = _to_tensor(y_train).to(self._device)
    exp_t = (
        _to_tensor(exposure).to(self._device)
        if exposure is not None
        else torch.ones_like(y_t)
    )

    n = y_t.shape[0]

    # Build or reserve validation set (must happen before bias init so we
    # use training-only data for the mean frequency estimate).
    if X_val is not None:
        x_val_dict = self._prepare_features(X_val)
        x_val_dict = self._to_device_dict(x_val_dict)
        y_val_t = _to_tensor(y_val).to(self._device)
        exp_val_t = (
            _to_tensor(exposure_val).to(self._device)
            if exposure_val is not None
            else torch.ones_like(y_val_t)
        )
    else:
        # Reserve val_fraction from training
        val_size = max(1, int(n * self.val_fraction))
        perm = torch.randperm(n, device=self._device)
        val_idx = perm[:val_size]
        train_idx = perm[val_size:]

        x_val_dict = {k: v[val_idx] for k, v in x_dict.items()}
        y_val_t = y_t[val_idx]
        exp_val_t = exp_t[val_idx]

        x_dict = {k: v[train_idx] for k, v in x_dict.items()}
        y_t = y_t[train_idx]
        exp_t = exp_t[train_idx]

    n_train = y_t.shape[0]

    # Initialise bias to log(mean_frequency) of the TRAINING split only.
    # Computing this from the full dataset (including validation) would
    # be a mild form of data leakage — small in practice but wrong in
    # principle. We compute it here, after the split.
    with torch.no_grad():
        mean_freq = y_t.mean().clamp(min=1e-8)
        self.output_bias.fill_(torch.log(mean_freq).item())

    optimizer = optim.Adam(self.parameters(), lr=self.lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=self.lr_factor,
        patience=self.lr_patience,
    )

    best_val_loss = float("inf")
    best_state = None
    epochs_no_improve = 0

    self.train_history = {"train_loss": [], "val_loss": []}

    for epoch in range(self.max_epochs):
        self.train()
        # Shuffle training data
        perm = torch.randperm(n_train, device=self._device)

        epoch_loss = 0.0
        n_batches = 0

        for start in range(0, n_train, self.batch_size):
            end = min(start + self.batch_size, n_train)
            batch_idx = perm[start:end]

            x_batch = {k: v[batch_idx] for k, v in x_dict.items()}
            y_batch = y_t[batch_idx]
            exp_batch = exp_t[batch_idx]

            optimizer.zero_grad()
            mu = self.forward(x_batch)
            loss = self._loss_fn(mu, y_batch, exp_batch)
            if torch.isnan(loss):
                continue  # skip NaN batches
            loss.backward()
            # Gradient clipping prevents exploding gradients in early training
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            # Zero NaN gradients before step
            for p in self.parameters():
                if p.grad is not None and torch.isnan(p.grad).any():
                    p.grad.zero_()
            optimizer.step()
            # Restore any NaN params to their pre-step values
            for p in self.parameters():
                if torch.isnan(p).any():
                    torch.nan_to_num_(p, nan=0.0)

            epoch_loss += loss.item()
            n_batches += 1

        # Guard against ZeroDivisionError: if every batch in the epoch was
        # skipped as NaN, n_batches is 0 — report the epoch loss as NaN and
        # let validation/early-stopping logic carry on.
        train_loss = epoch_loss / n_batches if n_batches > 0 else float("nan")

        # Validation
        self.eval()
        with torch.no_grad():
            mu_val = self.forward(x_val_dict)
            val_loss = self._loss_fn(mu_val, y_val_t, exp_val_t).item()

        scheduler.step(val_loss)

        self.train_history["train_loss"].append(train_loss)
        self.train_history["val_loss"].append(val_loss)

        if verbose and (epoch % 50 == 0 or epoch < 10):
            lr_now = optimizer.param_groups[0]["lr"]
            print(
                f"Epoch {epoch:4d} | train={train_loss:.6f} | val={val_loss:.6f} | lr={lr_now:.2e}"
            )

        # Early stopping (a NaN val_loss never counts as an improvement,
        # since NaN < best is False).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = copy.deepcopy(self.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= self.patience:
                if verbose:
                    print(f"Early stopping at epoch {epoch}. Best val={best_val_loss:.6f}")
                break

    # Restore best weights
    if best_state is not None:
        self.load_state_dict(best_state)

    # Post-hoc centering: compute pair means on training data
    self._compute_pair_centering(x_dict)

    self._is_fitted = True
    return self

Fit the model.

Arguments:
  • X_train: Training features. Dict, polars.DataFrame, or pandas.DataFrame.
  • y_train: Observed frequency (claims / exposure), shape (n,).
  • exposure: Per-sample exposure (years at risk), shape (n,).
  • X_val: Validation features (optional). If None, 10% of training data is reserved.
  • y_val: Validation targets.
  • exposure_val: Validation exposure.
  • verbose: Print training progress.
Returns:

self (for chaining).

def predict(self, X, exposure: Optional[numpy.ndarray] = None) -> numpy.ndarray:
def predict(
    self,
    X,
    exposure: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Predict frequency (or expected claims if exposure given).

    Args:
        X: Features. Dict, polars.DataFrame, or pandas.DataFrame.
        exposure: Per-sample exposure.

    Returns:
        Predictions as numpy array, shape (n,).

    Raises:
        RuntimeError: If called before fit().
    """
    if not self._is_fitted:
        raise RuntimeError("Call fit() before predict().")

    self.eval()
    inputs = self._to_device_dict(self._prepare_features(X))

    if exposure is None:
        exposure_t = None
    else:
        exposure_t = _to_tensor(exposure).to(self._device)

    with torch.no_grad():
        preds = self.forward(inputs, exposure=exposure_t)

    return preds.cpu().numpy()

Predict frequency (or expected claims if exposure given).

Arguments:
  • X: Features. Dict, polars.DataFrame, or pandas.DataFrame.
  • exposure: Per-sample exposure.
Returns:

Predictions as numpy array, shape (n,).

def pair_contributions(self, X) -> Dict[Tuple[str, str], numpy.ndarray]:
def pair_contributions(
    self,
    X,
) -> Dict[Tuple[str, str], np.ndarray]:
    """
    Compute w_{jk} * h_{jk}(x) for each pair and each sample.

    These are the additive components on the linear predictor scale.
    Useful for understanding which pairs drive predictions.

    Args:
        X: Features.

    Returns:
        Dict mapping (fname_j, fname_k) -> array of shape (n,).

    Raises:
        RuntimeError: If called before fit().
    """
    if not self._is_fitted:
        raise RuntimeError("Call fit() before pair_contributions().")

    self.eval()
    inputs = self._to_device_dict(self._prepare_features(X))
    contributions: Dict[Tuple[str, str], np.ndarray] = {}

    with torch.no_grad():
        # Embed every feature once, then evaluate each pair term.
        phi = self.feature_embeddings.embed_all(inputs)

        for idx, (j, k) in enumerate(self.interaction_tokens.pair_indices()):
            name_j = self.feature_names[j]
            name_k = self.feature_names[k]
            token = self.interaction_tokens.get_token(j, k)
            raw = self.shared_net(phi[name_j], phi[name_k], token)
            h = centered_hard_sigmoid(raw).squeeze(-1)
            term = self.output_weights[idx] * h
            # Apply the stored training-set centering offset when available.
            if self._pair_means is not None:
                term = term - self._pair_means[idx]
            contributions[(name_j, name_k)] = term.cpu().numpy()

    return contributions

Compute w_{jk} * h_{jk}(x) for each pair and each sample.

These are the additive components on the linear predictor scale. Useful for understanding which pairs drive predictions.

Arguments:
  • X: Features.
Returns:

Dict mapping (fname_j, fname_k) -> array of shape (n,).

def main_effects( self, X_background, n_grid: int = 100) -> Dict[str, Tuple[numpy.ndarray, numpy.ndarray]]:
def main_effects(
    self,
    X_background,
    n_grid: int = 100,
) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Compute main effect curves for each feature.

    For continuous features: evaluates the diagonal term w_{jj} * h_{jj}
    over an evenly-spaced grid covering the background data range; for
    categoricals, over all category codes. The diagonal pair h_{jj}
    depends only on feature j itself, so no other feature values are
    required for its evaluation.

    Args:
        X_background: Background data used to estimate continuous ranges.
        n_grid: Number of grid points for continuous features.

    Returns:
        Dict mapping feature_name -> (grid_values, effect_values).
        For categoricals: grid_values are integer category codes.

    Raises:
        RuntimeError: If called before fit().
    """
    if not self._is_fitted:
        raise RuntimeError("Call fit() before main_effects().")

    self.eval()
    bg_dict = self._to_device_dict(self._prepare_features(X_background))

    result = {}

    with torch.no_grad():
        for i, fname in enumerate(self.feature_names):
            spec = self.features[fname]

            # Build the evaluation grid for this feature.
            if spec == "continuous":
                vals = bg_dict[fname].float()
                lo, hi = vals.min().item(), vals.max().item()
                grid = torch.linspace(lo, hi, n_grid, device=self._device)
            elif isinstance(spec, int):
                grid = torch.arange(spec, device=self._device)
            else:
                continue  # unrecognised spec: skip this feature

            # Diagonal pair contribution (j = k = i). Only feature i's
            # values enter h_{ii}; the previous implementation also built
            # background-mean/mode values for every other feature, but
            # those were never read — that dead work is removed.
            x_i = grid.long() if isinstance(spec, int) else grid
            phi_i = self.feature_embeddings.embed_feature(fname, x_i)
            token = self.interaction_tokens.get_token(i, i)
            raw = self.shared_net(phi_i, phi_i, token)
            h = centered_hard_sigmoid(raw).squeeze(-1)
            pair_idx = self.interaction_tokens._pair_to_idx[(i, i)]
            effect = (self.output_weights[pair_idx] * h).cpu().numpy()

            result[fname] = (grid.cpu().numpy(), effect)

    return result

Compute main effect curves for each feature.

For continuous features: evaluates h_{jj} over an evenly-spaced grid while fixing all other features to background means.

Returns:

Dict mapping feature_name -> (grid_values, effect_values). For categoricals: grid_values are integer category codes.

def interaction_surfaces( self, X_background, n_grid: int = 30, pairs: Optional[List[Tuple[str, str]]] = None) -> Dict[Tuple[str, str], Dict]:
def interaction_surfaces(
    self,
    X_background,
    n_grid: int = 30,
    pairs: Optional[List[Tuple[str, str]]] = None,
) -> Dict[Tuple[str, str], Dict]:
    """
    Compute 2D interaction surfaces for feature pairs.

    For pair (j, k): evaluates w_{jk} * h_{jk}(x_j, x_k) over a 2D grid.
    Main effect contributions (j=j and k=k pairs) are NOT subtracted here —
    this shows the raw interaction term. Use pair_contributions() for
    full decomposition.

    Args:
        X_background: Background data for range estimation.
        n_grid: Grid resolution per axis (n_grid x n_grid for continuous).
        pairs: List of (feature_j, feature_k) pairs to compute. If None,
            computes all off-diagonal interaction pairs.

    Returns:
        Dict mapping (fname_j, fname_k) -> {
            'grid_j': array of shape (n_grid,),
            'grid_k': array of shape (n_grid,) or (n_cats,),
            'surface': array of shape (n_grid_j, n_grid_k),
        }

    Raises:
        RuntimeError: If called before fit().
    """
    if not self._is_fitted:
        raise RuntimeError("Call fit() before interaction_surfaces().")

    self.eval()
    bg_dict = self._prepare_features(X_background)
    bg_dict = self._to_device_dict(bg_dict)

    # All off-diagonal pairs if not specified
    if pairs is None:
        all_pairs = self.interaction_tokens.pair_indices()
        pairs = [
            (self.feature_names[j], self.feature_names[k])
            for j, k in all_pairs
            if j != k
        ]

    result = {}

    with torch.no_grad():
        for (fname_j, fname_k) in pairs:
            j = self.feature_names.index(fname_j)
            k = self.feature_names.index(fname_k)

            spec_j = self.features[fname_j]
            spec_k = self.features[fname_k]

            # Fix: pass Python scalars (.item()) to torch.linspace rather
            # than 0-dim tensors — consistent with main_effects() and
            # required by older torch versions.
            if spec_j == "continuous":
                vals_j = bg_dict[fname_j].float()
                grid_j = torch.linspace(
                    vals_j.min().item(), vals_j.max().item(), n_grid, device=self._device
                )
            else:
                grid_j = torch.arange(spec_j, device=self._device)

            if spec_k == "continuous":
                vals_k = bg_dict[fname_k].float()
                grid_k = torch.linspace(
                    vals_k.min().item(), vals_k.max().item(), n_grid, device=self._device
                )
            else:
                grid_k = torch.arange(spec_k, device=self._device)

            nj, nk = grid_j.shape[0], grid_k.shape[0]

            # Flattened meshgrid: every (x_j, x_k) combination.
            gj = grid_j.repeat_interleave(nk)  # (nj * nk,)
            gk = grid_k.repeat(nj)              # (nj * nk,)

            phi_j = self.feature_embeddings.embed_feature(
                fname_j, gj.long() if isinstance(spec_j, int) else gj
            )
            phi_k = self.feature_embeddings.embed_feature(
                fname_k, gk.long() if isinstance(spec_k, int) else gk
            )

            token = self.interaction_tokens.get_token(j, k)
            raw = self.shared_net(phi_j, phi_k, token)
            h = centered_hard_sigmoid(raw).squeeze(-1)
            # NOTE(review): _pair_to_idx appears to be keyed with j <= k;
            # a caller-supplied pair in reversed positional order may raise
            # KeyError here — confirm against InteractionTokens.
            pair_idx = self.interaction_tokens._pair_to_idx[(j, k)]
            w = self.output_weights[pair_idx]
            surface = (w * h).reshape(nj, nk).cpu().numpy()

            result[(fname_j, fname_k)] = {
                "grid_j": grid_j.cpu().numpy(),
                "grid_k": grid_k.cpu().numpy(),
                "surface": surface,
            }

    return result

Compute 2D interaction surfaces for feature pairs.

For pair (j, k): evaluates w_{jk} * h_{jk}(x_j, x_k) over a 2D grid. Main effect contributions (j=j and k=k pairs) are NOT subtracted here — this shows the raw interaction term. Use pair_contributions() for full decomposition.

Arguments:
  • X_background: Background data for range estimation.
  • n_grid: Grid resolution per axis (n_grid x n_grid for continuous).
  • pairs: List of (feature_j, feature_k) pairs to compute. If None, computes all off-diagonal interaction pairs.
Returns:

Dict mapping (fname_j, fname_k) -> { 'grid_j': array of shape (n_grid,), 'grid_k': array of shape (n_grid,) or (n_cats,), 'surface': array of shape (n_grid_j, n_grid_k), }

def shapley_values( self, X, X_background, n_background: int = 100) -> Dict[str, numpy.ndarray]:
def shapley_values(
    self,
    X,
    X_background,
    n_background: int = 100,
) -> Dict[str, np.ndarray]:
    """
    Compute exact Shapley values using the pairwise additive structure.

    Cost: 2(q+1) forward passes per test sample per background sample.
    For q=9, n_background=100: ~2000 forward passes per test sample.
    This is exact — no sampling approximation.

    Args:
        X: Test data.
        X_background: Background data for baseline distribution.
        n_background: Number of background samples to use.

    Returns:
        Dict mapping feature_name -> (n_test,) array of SHAP values.
        Values are on the linear predictor scale.

    Raises:
        RuntimeError: If called before fit().
    """
    # Check fitted state before touching the optional dependency so an
    # unfitted call gets the actionable error; previously the deferred
    # import ran first and could surface as an ImportError instead.
    if not self._is_fitted:
        raise RuntimeError("Call fit() before shapley_values().")

    # Deferred import: shapley support ships with the optional pin extra.
    from insurance_gam.pin.shapley import exact_shapley_values

    x_dict = self._to_device_dict(self._prepare_features(X))
    bg_dict = self._to_device_dict(self._prepare_features(X_background))

    return exact_shapley_values(self, x_dict, bg_dict, n_background)

Compute exact Shapley values using the pairwise additive structure.

Cost: 2(q+1) forward passes per test sample per background sample. For q=9, n_background=100: ~2000 forward passes per test sample. This is exact — no sampling approximation.

Arguments:
  • X: Test data.
  • X_background: Background data for baseline distribution.
  • n_background: Number of background samples to use.
Returns:

Dict mapping feature_name -> (n_test,) array of SHAP values. Values are on the linear predictor scale.

def interaction_weights(self) -> Dict[Tuple[str, str], float]:
def interaction_weights(self) -> Dict[Tuple[str, str], float]:
    """
    Return output weights w_{jk} for all pairs as a dict.

    Large absolute weights indicate important interactions or main effects.
    Note: weights are on the scale of the linear predictor; compare
    w_{jk} * h_{jk} range rather than w_{jk} alone for fair comparison.

    Returns:
        Dict mapping (fname_j, fname_k) -> float.
    """
    names = self.feature_names
    raw = self.output_weights.detach().cpu().numpy()
    out: Dict[Tuple[str, str], float] = {}
    for idx, (j, k) in enumerate(self.interaction_tokens.pair_indices()):
        out[(names[j], names[k])] = float(raw[idx])
    return out

Return output weights w_{jk} for all pairs as a dict.

Large absolute weights indicate important interactions or main effects. Note: weights are on the scale of the linear predictor; compare w_{jk} * h_{jk} range rather than w_{jk} alone for fair comparison.

Returns:

Dict mapping (fname_j, fname_k) -> float.

def count_parameters(self) -> int:
def count_parameters(self) -> int:
    """Return the total number of trainable parameters in the model."""
    total = 0
    for param in self.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

Count trainable parameters.

class PINEnsemble:
class PINEnsemble:
    """
    Ensemble of PINModel instances trained with different random seeds.

    Ensemble averaging is the primary regularisation strategy in the PIN paper.
    With n=10 runs, ensemble PIN achieves the best published result on French MTPL.

    Args:
        n_models: Number of models in the ensemble.
        **kwargs: Passed to each PINModel constructor.
    """

    def __init__(self, n_models: int = 10, **kwargs) -> None:
        self.n_models = n_models
        self.kwargs = kwargs
        self.models: List[PINModel] = []
        self._is_fitted = False

    def fit(
        self,
        X_train,
        y_train: np.ndarray,
        exposure: Optional[np.ndarray] = None,
        X_val=None,
        y_val: Optional[np.ndarray] = None,
        exposure_val: Optional[np.ndarray] = None,
        verbose: bool = False,
    ) -> "PINEnsemble":
        """
        Fit all models with different seeds.

        Args:
            X_train: Training features.
            y_train: Observed frequency.
            exposure: Training exposure.
            X_val: Validation features (optional).
            y_val: Validation targets.
            exposure_val: Validation exposure.
            verbose: Print per-model progress.

        Returns:
            self.
        """
        self.models = []
        for i in range(self.n_models):
            # Each member gets a distinct, reproducible seed offset from the
            # base seed (default 42).
            seed = self.kwargs.get("random_seed", 42) + i
            kw = {**self.kwargs, "random_seed": seed}
            model = PINModel(**kw)
            print(f"[Ensemble] Fitting model {i+1}/{self.n_models} (seed={seed})")
            model.fit(
                X_train,
                y_train,
                exposure=exposure,
                X_val=X_val,
                y_val=y_val,
                exposure_val=exposure_val,
                verbose=verbose,
            )
            self.models.append(model)

        self._is_fitted = True
        return self

    def predict(
        self,
        X,
        exposure: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Average predictions across all models.

        Args:
            X: Features.
            exposure: Per-sample exposure.

        Returns:
            Ensemble mean prediction, shape (n,).

        Raises:
            RuntimeError: If called before fit().
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before predict().")

        preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0)
        return preds.mean(axis=0)

    def predict_std(
        self,
        X,
        exposure: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Standard deviation of predictions across models (epistemic uncertainty).

        Returns:
            Shape (n,).

        Raises:
            RuntimeError: If called before fit().
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before predict_std().")

        preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0)
        return preds.std(axis=0)

    def shapley_values(
        self,
        X,
        X_background,
        n_background: int = 100,
    ) -> Dict[str, np.ndarray]:
        """
        Average Shapley values across ensemble members.

        Returns:
            Dict mapping feature_name -> (n_test,) array.

        Raises:
            RuntimeError: If called before fit().
        """
        if not self._is_fitted:
            raise RuntimeError("Call fit() before shapley_values().")

        all_shaps = [m.shapley_values(X, X_background, n_background) for m in self.models]
        feature_names = list(all_shaps[0].keys())

        return {
            fname: np.stack([s[fname] for s in all_shaps], axis=0).mean(axis=0)
            for fname in feature_names
        }

    def interaction_weights(self) -> Dict[Tuple[str, str], float]:
        """
        Mean absolute interaction weights across ensemble.

        Returns:
            Dict mapping (fname_j, fname_k) -> mean |w_{jk}|.

        Raises:
            RuntimeError: If called before fit().
        """
        # Consistency fix: like predict()/shapley_values(), fail loudly when
        # unfitted instead of silently returning an empty dict.
        if not self._is_fitted:
            raise RuntimeError("Call fit() before interaction_weights().")

        from collections import defaultdict
        sums: Dict = defaultdict(float)
        for model in self.models:
            for pair, w in model.interaction_weights().items():
                sums[pair] += abs(w)
        return {pair: v / self.n_models for pair, v in sums.items()}

Ensemble of PINModel instances trained with different random seeds.

Ensemble averaging is the primary regularisation strategy in the PIN paper. With n=10 runs, ensemble PIN achieves the best published result on French MTPL.

Arguments:
  • n_models: Number of models in the ensemble.
  • **kwargs: Passed to each PINModel constructor.
PINEnsemble(n_models: int = 10, **kwargs)
864    def __init__(self, n_models: int = 10, **kwargs) -> None:
865        self.n_models = n_models
866        self.kwargs = kwargs
867        self.models: List[PINModel] = []
868        self._is_fitted = False
n_models
kwargs
models: List[PINModel]
def fit( self, X_train, y_train: numpy.ndarray, exposure: Optional[numpy.ndarray] = None, X_val=None, y_val: Optional[numpy.ndarray] = None, exposure_val: Optional[numpy.ndarray] = None, verbose: bool = False) -> PINEnsemble:
870    def fit(
871        self,
872        X_train,
873        y_train: np.ndarray,
874        exposure: Optional[np.ndarray] = None,
875        X_val=None,
876        y_val: Optional[np.ndarray] = None,
877        exposure_val: Optional[np.ndarray] = None,
878        verbose: bool = False,
879    ) -> "PINEnsemble":
880        """
881        Fit all models with different seeds.
882
883        Args:
884            X_train: Training features.
885            y_train: Observed frequency.
886            exposure: Training exposure.
887            X_val: Validation features (optional).
888            y_val: Validation targets.
889            exposure_val: Validation exposure.
890            verbose: Print per-model progress.
891
892        Returns:
893            self.
894        """
895        self.models = []
896        for i in range(self.n_models):
897            seed = self.kwargs.get("random_seed", 42) + i
898            kw = {**self.kwargs, "random_seed": seed}
899            model = PINModel(**kw)
900            print(f"[Ensemble] Fitting model {i+1}/{self.n_models} (seed={seed})")
901            model.fit(
902                X_train,
903                y_train,
904                exposure=exposure,
905                X_val=X_val,
906                y_val=y_val,
907                exposure_val=exposure_val,
908                verbose=verbose,
909            )
910            self.models.append(model)
911
912        self._is_fitted = True
913        return self

Fit all models with different seeds.

Arguments:
  • X_train: Training features.
  • y_train: Observed frequency.
  • exposure: Training exposure.
  • X_val: Validation features (optional).
  • y_val: Validation targets.
  • exposure_val: Validation exposure.
  • verbose: Print per-model progress.
Returns:

self.

def predict(self, X, exposure: Optional[numpy.ndarray] = None) -> numpy.ndarray:
915    def predict(
916        self,
917        X,
918        exposure: Optional[np.ndarray] = None,
919    ) -> np.ndarray:
920        """
921        Average predictions across all models.
922
923        Args:
924            X: Features.
925            exposure: Per-sample exposure.
926
927        Returns:
928            Ensemble mean prediction, shape (n,).
929        """
930        if not self._is_fitted:
931            raise RuntimeError("Call fit() before predict().")
932
933        preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0)
934        return preds.mean(axis=0)

Average predictions across all models.

Arguments:
  • X: Features.
  • exposure: Per-sample exposure.
Returns:

Ensemble mean prediction, shape (n,).

def predict_std(self, X, exposure: Optional[numpy.ndarray] = None) -> numpy.ndarray:
936    def predict_std(
937        self,
938        X,
939        exposure: Optional[np.ndarray] = None,
940    ) -> np.ndarray:
941        """
942        Standard deviation of predictions across models (epistemic uncertainty).
943
944        Returns:
945            Shape (n,).
946        """
947        if not self._is_fitted:
948            raise RuntimeError("Call fit() before predict_std().")
949
950        preds = np.stack([m.predict(X, exposure=exposure) for m in self.models], axis=0)
951        return preds.std(axis=0)

Standard deviation of predictions across models (epistemic uncertainty).

Returns:

Shape (n,).

def shapley_values( self, X, X_background, n_background: int = 100) -> Dict[str, numpy.ndarray]:
953    def shapley_values(
954        self,
955        X,
956        X_background,
957        n_background: int = 100,
958    ) -> Dict[str, np.ndarray]:
959        """
960        Average Shapley values across ensemble members.
961
962        Returns:
963            Dict mapping feature_name -> (n_test,) array.
964        """
965        if not self._is_fitted:
966            raise RuntimeError("Call fit() before shapley_values().")
967
968        all_shaps = [m.shapley_values(X, X_background, n_background) for m in self.models]
969        feature_names = list(all_shaps[0].keys())
970
971        return {
972            fname: np.stack([s[fname] for s in all_shaps], axis=0).mean(axis=0)
973            for fname in feature_names
974        }

Average Shapley values across ensemble members.

Returns:

Dict mapping feature_name -> (n_test,) array.

def interaction_weights(self) -> Dict[Tuple[str, str], float]:
976    def interaction_weights(self) -> Dict[Tuple[str, str], float]:
977        """
978        Mean absolute interaction weights across ensemble.
979
980        Returns:
981            Dict mapping (fname_j, fname_k) -> mean |w_{jk}|.
982        """
983        from collections import defaultdict
984        sums: Dict = defaultdict(float)
985        for model in self.models:
986            for pair, w in model.interaction_weights().items():
987                sums[pair] += abs(w)
988        return {pair: v / self.n_models for pair, v in sums.items()}

Mean absolute interaction weights across ensemble.

Returns:

Dict mapping (fname_j, fname_k) -> mean |w_{jk}|.

class PINDiagnostics:
 26class PINDiagnostics:
 27    """
 28    Visualisation tools for a fitted PINModel.
 29
 30    Args:
 31        model: Fitted PINModel instance.
 32    """
 33
 34    def __init__(self, model) -> None:
 35        self.model = model
 36
 37    def interaction_heatmap(
 38        self,
 39        figsize: Tuple[int, int] = (8, 7),
 40        cmap: str = "RdBu_r",
 41        title: str = "PIN Interaction Weights |w_{jk}|",
 42        ax=None,
 43    ):
 44        """
 45        Heatmap of interaction weight magnitudes.
 46
 47        Rows and columns are features. The upper triangle (j<k) shows
 48        interactions; the diagonal shows main effects.
 49
 50        Large values indicate pairs that contribute more variance to the
 51        linear predictor. Note: weight magnitude alone doesn't measure
 52        importance — the range of h_{jk} also matters. Use
 53        weighted_importance() for a more meaningful ranking.
 54
 55        Args:
 56            figsize: Figure size.
 57            cmap: Colormap.
 58            title: Plot title.
 59            ax: Existing axes to draw on (optional).
 60
 61        Returns:
 62            (fig, ax) tuple.
 63        """
 64        plt = _get_plt()
 65        weights = self.model.interaction_weights()
 66        feature_names = self.model.feature_names
 67        q = len(feature_names)
 68
 69        matrix = np.zeros((q, q))
 70        for (fname_j, fname_k), w in weights.items():
 71            j = feature_names.index(fname_j)
 72            k = feature_names.index(fname_k)
 73            matrix[j, k] = abs(w)
 74            if j != k:
 75                matrix[k, j] = abs(w)
 76
 77        if ax is None:
 78            fig, ax = plt.subplots(figsize=figsize)
 79        else:
 80            fig = ax.get_figure()
 81
 82        im = ax.imshow(matrix, cmap=cmap, aspect="auto")
 83        plt.colorbar(im, ax=ax, label="|w_{jk}|")
 84
 85        ax.set_xticks(range(q))
 86        ax.set_yticks(range(q))
 87        ax.set_xticklabels(feature_names, rotation=45, ha="right")
 88        ax.set_yticklabels(feature_names)
 89        ax.set_title(title)
 90
 91        for i in range(q):
 92            for j in range(q):
 93                ax.text(
 94                    j, i, f"{matrix[i, j]:.3f}",
 95                    ha="center", va="center", fontsize=7,
 96                )
 97
 98        fig.tight_layout()
 99        return fig, ax
100
101    def weighted_importance(
102        self,
103        X_background,
104        top_n: Optional[int] = None,
105        figsize: Tuple[int, int] = (8, 5),
106        ax=None,
107    ):
108        """
109        Rank interaction pairs by the range of w_{jk} * h_{jk}(x).
110
111        The range (max - min) over the background data measures how much
112        the pair actually varies — a large weight with a flat h_{jk} adds
113        almost nothing. This gives a fairer importance metric than |w_{jk}|.
114
115        Args:
116            X_background: Background data for evaluating h_{jk}.
117            top_n: Plot only top N pairs. None = all.
118            figsize: Figure size.
119            ax: Existing axes.
120
121        Returns:
122            (fig, ax, importance_dict).
123        """
124        plt = _get_plt()
125        contribs = self.model.pair_contributions(X_background)
126
127        importance = {}
128        for (fname_j, fname_k), vals in contribs.items():
129            label = f"{fname_j}" if fname_j == fname_k else f"{fname_j} x {fname_k}"
130            importance[label] = float(vals.max() - vals.min())
131
132        sorted_items = sorted(importance.items(), key=lambda x: x[1], reverse=True)
133        if top_n is not None:
134            sorted_items = sorted_items[:top_n]
135
136        labels = [it[0] for it in sorted_items]
137        values = [it[1] for it in sorted_items]
138
139        if ax is None:
140            fig, ax = plt.subplots(figsize=figsize)
141        else:
142            fig = ax.get_figure()
143
144        colors = [
145            "#2196F3" if " x " not in label else "#FF5722"
146            for label in labels
147        ]
148
149        ax.barh(labels[::-1], values[::-1], color=colors[::-1])
150        ax.set_xlabel("Range of w_{jk} * h_{jk}(x) on background")
151        ax.set_title("PIN Pair Importance (main effects=blue, interactions=orange)")
152        ax.grid(axis="x", alpha=0.3)
153        fig.tight_layout()
154
155        return fig, ax, dict(sorted_items)
156
157    def plot_main_effect(
158        self,
159        feature: str,
160        X_background,
161        n_grid: int = 100,
162        figsize: Tuple[int, int] = (6, 4),
163        ax=None,
164        color: str = "#2196F3",
165    ):
166        """
167        Plot main effect curve for a single feature.
168
169        Evaluates the diagonal pair term w_{jj} * h_{jj}(x_j) over a grid
170        of feature values while fixing other features to background means.
171
172        Args:
173            feature: Feature name.
174            X_background: Background data.
175            n_grid: Grid resolution.
176            figsize: Figure size.
177            ax: Existing axes.
178            color: Line colour.
179
180        Returns:
181            (fig, ax) tuple.
182        """
183        plt = _get_plt()
184        effects = self.model.main_effects(X_background, n_grid=n_grid)
185
186        if feature not in effects:
187            raise ValueError(f"Feature '{feature}' not found. Available: {list(effects.keys())}")
188
189        grid_vals, effect_vals = effects[feature]
190
191        if ax is None:
192            fig, ax = plt.subplots(figsize=figsize)
193        else:
194            fig = ax.get_figure()
195
196        spec = self.model.features[feature]
197        if isinstance(spec, int):
198            ax.bar(grid_vals, effect_vals, color=color, alpha=0.8)
199            ax.set_xlabel(f"{feature} (category)")
200        else:
201            ax.plot(grid_vals, effect_vals, color=color, linewidth=2)
202            ax.fill_between(grid_vals, effect_vals, alpha=0.15, color=color)
203            ax.set_xlabel(feature)
204
205        ax.axhline(0, color="gray", linestyle="--", alpha=0.5)
206        ax.set_ylabel("w_{jj} * h_{jj}(x) (linear predictor scale)")
207        ax.set_title(f"Main Effect: {feature}")
208        ax.grid(alpha=0.3)
209        fig.tight_layout()
210
211        return fig, ax
212
213    def plot_surface(
214        self,
215        feature_j: str,
216        feature_k: str,
217        X_background,
218        n_grid: int = 30,
219        figsize: Tuple[int, int] = (7, 5),
220        cmap: str = "RdBu_r",
221        ax=None,
222    ):
223        """
224        Plot 2D interaction surface for a feature pair.
225
226        Evaluates w_{jk} * h_{jk}(x_j, x_k) over a grid. Points with
227        high absolute values indicate combinations where the interaction
228        adds or subtracts substantially from the linear predictor.
229
230        Args:
231            feature_j: First feature (x-axis for continuous).
232            feature_k: Second feature (y-axis or color dimension).
233            X_background: Background data for axis range estimation.
234            n_grid: Grid resolution per axis.
235            figsize: Figure size.
236            cmap: Colormap.
237            ax: Existing axes.
238
239        Returns:
240            (fig, ax) tuple.
241        """
242        plt = _get_plt()
243        surfaces = self.model.interaction_surfaces(
244            X_background,
245            n_grid=n_grid,
246            pairs=[(feature_j, feature_k)],
247        )
248
249        key = (feature_j, feature_k)
250        if key not in surfaces:
251            key = (feature_k, feature_j)
252
253        if key not in surfaces:
254            raise ValueError(
255                f"Pair ({feature_j}, {feature_k}) not found in surfaces. "
256                "Ensure both features exist and j < k in feature order."
257            )
258
259        surf_data = surfaces[key]
260        grid_j = surf_data["grid_j"]
261        grid_k = surf_data["grid_k"]
262        surface = surf_data["surface"]
263
264        if ax is None:
265            fig, ax = plt.subplots(figsize=figsize)
266        else:
267            fig = ax.get_figure()
268
269        vmax = np.abs(surface).max()
270        vmin = -vmax
271
272        spec_j = self.model.features[feature_j]
273        spec_k = self.model.features[feature_k]
274
275        if isinstance(spec_k, int):
276            for cat_idx in range(grid_k.shape[0]):
277                ax.plot(
278                    grid_j,
279                    surface[:, cat_idx],
280                    label=f"{feature_k}={int(grid_k[cat_idx])}",
281                    alpha=0.8,
282                )
283            ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
284            ax.set_xlabel(feature_j)
285            ax.set_ylabel("w * h (interaction)")
286        else:
287            im = ax.pcolormesh(
288                grid_j, grid_k, surface.T,
289                cmap=cmap, vmin=vmin, vmax=vmax,
290            )
291            plt.colorbar(im, ax=ax, label="w_{jk} * h_{jk}(x)")
292            ax.set_xlabel(feature_j)
293            ax.set_ylabel(feature_k)
294
295        ax.set_title(f"Interaction Surface: {feature_j} x {feature_k}")
296        fig.tight_layout()
297
298        return fig, ax
299
300    def plot_training_history(
301        self,
302        figsize: Tuple[int, int] = (8, 4),
303        ax=None,
304    ):
305        """
306        Plot training and validation loss curves.
307
308        Args:
309            figsize: Figure size.
310            ax: Existing axes.
311
312        Returns:
313            (fig, ax) tuple.
314        """
315        plt = _get_plt()
316        history = self.model.train_history
317
318        if ax is None:
319            fig, ax = plt.subplots(figsize=figsize)
320        else:
321            fig = ax.get_figure()
322
323        epochs = range(len(history["train_loss"]))
324        ax.plot(epochs, history["train_loss"], label="Train", color="#2196F3")
325        ax.plot(epochs, history["val_loss"], label="Validation", color="#FF5722")
326        ax.set_xlabel("Epoch")
327        ax.set_ylabel("Loss (deviance)")
328        ax.set_title("PIN Training History")
329        ax.legend()
330        ax.grid(alpha=0.3)
331        fig.tight_layout()
332
333        return fig, ax
334
335    def summary(self, X_background=None) -> str:
336        """
337        Print a text summary of the fitted model.
338
339        Args:
340            X_background: Optional background data for pair importance ranking.
341
342        Returns:
343            Summary string.
344        """
345        model = self.model
346        lines = [
347            "=" * 60,
348            "PIN Model Summary",
349            "=" * 60,
350            f"Features:   {model.q}",
351            f"Parameters: {model.count_parameters():,}",
352            f"Loss:       {model.loss_name}",
353            f"Embedding:  d={model.embedding_dim}, d'={model.hidden_dim}, d0={model.token_dim}",
354            f"Shared net: d1={model.shared_dims[0]}, d2={model.shared_dims[1]}",
355            f"n_pairs:    {model.q * (model.q + 1) // 2} "
356            f"({model.q} main + {model.q*(model.q-1)//2} interactions)",
357            "",
358            "Feature list:",
359        ]
360        for name in model.feature_names:
361            spec = model.features[name]
362            ftype = "continuous" if spec == "continuous" else f"categorical ({spec} levels)"
363            lines.append(f"  {name}: {ftype}")
364
365        lines.append("")
366        lines.append("Output weights (|w_{jk}|, top 10):")
367        weights = model.interaction_weights()
368        sorted_w = sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True)[:10]
369        for (fj, fk), w in sorted_w:
370            tag = "(main)" if fj == fk else "(interaction)"
371            lines.append(f"  {fj} x {fk}: {w:+.4f} {tag}")
372
373        if model.train_history["val_loss"]:
374            best_val = min(model.train_history["val_loss"])
375            lines.append(f"\nBest val loss: {best_val:.6f}")
376            lines.append(f"Epochs run:   {len(model.train_history['val_loss'])}")
377
378        lines.append("=" * 60)
379        txt = "\n".join(lines)
380        print(txt)
381        return txt

Visualisation tools for a fitted PINModel.

Arguments:
  • model: Fitted PINModel instance.
PINDiagnostics(model)
34    def __init__(self, model) -> None:
35        self.model = model
model
def interaction_heatmap( self, figsize: Tuple[int, int] = (8, 7), cmap: str = 'RdBu_r', title: str = 'PIN Interaction Weights |w_{jk}|', ax=None):
37    def interaction_heatmap(
38        self,
39        figsize: Tuple[int, int] = (8, 7),
40        cmap: str = "RdBu_r",
41        title: str = "PIN Interaction Weights |w_{jk}|",
42        ax=None,
43    ):
44        """
45        Heatmap of interaction weight magnitudes.
46
47        Rows and columns are features. The upper triangle (j<k) shows
48        interactions; the diagonal shows main effects.
49
50        Large values indicate pairs that contribute more variance to the
51        linear predictor. Note: weight magnitude alone doesn't measure
52        importance — the range of h_{jk} also matters. Use
53        weighted_importance() for a more meaningful ranking.
54
55        Args:
56            figsize: Figure size.
57            cmap: Colormap.
58            title: Plot title.
59            ax: Existing axes to draw on (optional).
60
61        Returns:
62            (fig, ax) tuple.
63        """
64        plt = _get_plt()
65        weights = self.model.interaction_weights()
66        feature_names = self.model.feature_names
67        q = len(feature_names)
68
69        matrix = np.zeros((q, q))
70        for (fname_j, fname_k), w in weights.items():
71            j = feature_names.index(fname_j)
72            k = feature_names.index(fname_k)
73            matrix[j, k] = abs(w)
74            if j != k:
75                matrix[k, j] = abs(w)
76
77        if ax is None:
78            fig, ax = plt.subplots(figsize=figsize)
79        else:
80            fig = ax.get_figure()
81
82        im = ax.imshow(matrix, cmap=cmap, aspect="auto")
83        plt.colorbar(im, ax=ax, label="|w_{jk}|")
84
85        ax.set_xticks(range(q))
86        ax.set_yticks(range(q))
87        ax.set_xticklabels(feature_names, rotation=45, ha="right")
88        ax.set_yticklabels(feature_names)
89        ax.set_title(title)
90
91        for i in range(q):
92            for j in range(q):
93                ax.text(
94                    j, i, f"{matrix[i, j]:.3f}",
95                    ha="center", va="center", fontsize=7,
96                )
97
98        fig.tight_layout()
99        return fig, ax

Heatmap of interaction weight magnitudes.

Rows and columns are features. The upper triangle (j&lt;k) shows interactions; the diagonal shows main effects.

Large values indicate pairs that contribute more variance to the linear predictor. Note: weight magnitude alone doesn't measure importance — the range of h_{jk} also matters. Use weighted_importance() for a more meaningful ranking.

Arguments:
  • figsize: Figure size.
  • cmap: Colormap.
  • title: Plot title.
  • ax: Existing axes to draw on (optional).
Returns:

(fig, ax) tuple.

def weighted_importance( self, X_background, top_n: Optional[int] = None, figsize: Tuple[int, int] = (8, 5), ax=None):
101    def weighted_importance(
102        self,
103        X_background,
104        top_n: Optional[int] = None,
105        figsize: Tuple[int, int] = (8, 5),
106        ax=None,
107    ):
108        """
109        Rank interaction pairs by the range of w_{jk} * h_{jk}(x).
110
111        The range (max - min) over the background data measures how much
112        the pair actually varies — a large weight with a flat h_{jk} adds
113        almost nothing. This gives a fairer importance metric than |w_{jk}|.
114
115        Args:
116            X_background: Background data for evaluating h_{jk}.
117            top_n: Plot only top N pairs. None = all.
118            figsize: Figure size.
119            ax: Existing axes.
120
121        Returns:
122            (fig, ax, importance_dict).
123        """
124        plt = _get_plt()
125        contribs = self.model.pair_contributions(X_background)
126
127        importance = {}
128        for (fname_j, fname_k), vals in contribs.items():
129            label = f"{fname_j}" if fname_j == fname_k else f"{fname_j} x {fname_k}"
130            importance[label] = float(vals.max() - vals.min())
131
132        sorted_items = sorted(importance.items(), key=lambda x: x[1], reverse=True)
133        if top_n is not None:
134            sorted_items = sorted_items[:top_n]
135
136        labels = [it[0] for it in sorted_items]
137        values = [it[1] for it in sorted_items]
138
139        if ax is None:
140            fig, ax = plt.subplots(figsize=figsize)
141        else:
142            fig = ax.get_figure()
143
144        colors = [
145            "#2196F3" if " x " not in label else "#FF5722"
146            for label in labels
147        ]
148
149        ax.barh(labels[::-1], values[::-1], color=colors[::-1])
150        ax.set_xlabel("Range of w_{jk} * h_{jk}(x) on background")
151        ax.set_title("PIN Pair Importance (main effects=blue, interactions=orange)")
152        ax.grid(axis="x", alpha=0.3)
153        fig.tight_layout()
154
155        return fig, ax, dict(sorted_items)

Rank interaction pairs by the range of w_{jk} * h_{jk}(x).

The range (max - min) over the background data measures how much the pair actually varies — a large weight with a flat h_{jk} adds almost nothing. This gives a fairer importance metric than |w_{jk}|.

Arguments:
  • X_background: Background data for evaluating h_{jk}.
  • top_n: Plot only top N pairs. None = all.
  • figsize: Figure size.
  • ax: Existing axes.
Returns:

(fig, ax, importance_dict).

def plot_main_effect( self, feature: str, X_background, n_grid: int = 100, figsize: Tuple[int, int] = (6, 4), ax=None, color: str = '#2196F3'):
157    def plot_main_effect(
158        self,
159        feature: str,
160        X_background,
161        n_grid: int = 100,
162        figsize: Tuple[int, int] = (6, 4),
163        ax=None,
164        color: str = "#2196F3",
165    ):
166        """
167        Plot main effect curve for a single feature.
168
169        Evaluates the diagonal pair term w_{jj} * h_{jj}(x_j) over a grid
170        of feature values while fixing other features to background means.
171
172        Args:
173            feature: Feature name.
174            X_background: Background data.
175            n_grid: Grid resolution.
176            figsize: Figure size.
177            ax: Existing axes.
178            color: Line colour.
179
180        Returns:
181            (fig, ax) tuple.
182        """
183        plt = _get_plt()
184        effects = self.model.main_effects(X_background, n_grid=n_grid)
185
186        if feature not in effects:
187            raise ValueError(f"Feature '{feature}' not found. Available: {list(effects.keys())}")
188
189        grid_vals, effect_vals = effects[feature]
190
191        if ax is None:
192            fig, ax = plt.subplots(figsize=figsize)
193        else:
194            fig = ax.get_figure()
195
196        spec = self.model.features[feature]
197        if isinstance(spec, int):
198            ax.bar(grid_vals, effect_vals, color=color, alpha=0.8)
199            ax.set_xlabel(f"{feature} (category)")
200        else:
201            ax.plot(grid_vals, effect_vals, color=color, linewidth=2)
202            ax.fill_between(grid_vals, effect_vals, alpha=0.15, color=color)
203            ax.set_xlabel(feature)
204
205        ax.axhline(0, color="gray", linestyle="--", alpha=0.5)
206        ax.set_ylabel("w_{jj} * h_{jj}(x) (linear predictor scale)")
207        ax.set_title(f"Main Effect: {feature}")
208        ax.grid(alpha=0.3)
209        fig.tight_layout()
210
211        return fig, ax

Plot main effect curve for a single feature.

Evaluates the diagonal pair term w_{jj} * h_{jj}(x_j) over a grid of feature values while fixing other features to background means.

Arguments:
  • feature: Feature name.
  • X_background: Background data.
  • n_grid: Grid resolution.
  • figsize: Figure size.
  • ax: Existing axes.
  • color: Line colour.
Returns:

(fig, ax) tuple.

def plot_surface( self, feature_j: str, feature_k: str, X_background, n_grid: int = 30, figsize: Tuple[int, int] = (7, 5), cmap: str = 'RdBu_r', ax=None):
213    def plot_surface(
214        self,
215        feature_j: str,
216        feature_k: str,
217        X_background,
218        n_grid: int = 30,
219        figsize: Tuple[int, int] = (7, 5),
220        cmap: str = "RdBu_r",
221        ax=None,
222    ):
223        """
224        Plot 2D interaction surface for a feature pair.
225
226        Evaluates w_{jk} * h_{jk}(x_j, x_k) over a grid. Points with
227        high absolute values indicate combinations where the interaction
228        adds or subtracts substantially from the linear predictor.
229
230        Args:
231            feature_j: First feature (x-axis for continuous).
232            feature_k: Second feature (y-axis or color dimension).
233            X_background: Background data for axis range estimation.
234            n_grid: Grid resolution per axis.
235            figsize: Figure size.
236            cmap: Colormap.
237            ax: Existing axes.
238
239        Returns:
240            (fig, ax) tuple.
241        """
242        plt = _get_plt()
243        surfaces = self.model.interaction_surfaces(
244            X_background,
245            n_grid=n_grid,
246            pairs=[(feature_j, feature_k)],
247        )
248
249        key = (feature_j, feature_k)
250        if key not in surfaces:
251            key = (feature_k, feature_j)
252
253        if key not in surfaces:
254            raise ValueError(
255                f"Pair ({feature_j}, {feature_k}) not found in surfaces. "
256                "Ensure both features exist and j < k in feature order."
257            )
258
259        surf_data = surfaces[key]
260        grid_j = surf_data["grid_j"]
261        grid_k = surf_data["grid_k"]
262        surface = surf_data["surface"]
263
264        if ax is None:
265            fig, ax = plt.subplots(figsize=figsize)
266        else:
267            fig = ax.get_figure()
268
269        vmax = np.abs(surface).max()
270        vmin = -vmax
271
272        spec_j = self.model.features[feature_j]
273        spec_k = self.model.features[feature_k]
274
275        if isinstance(spec_k, int):
276            for cat_idx in range(grid_k.shape[0]):
277                ax.plot(
278                    grid_j,
279                    surface[:, cat_idx],
280                    label=f"{feature_k}={int(grid_k[cat_idx])}",
281                    alpha=0.8,
282                )
283            ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
284            ax.set_xlabel(feature_j)
285            ax.set_ylabel("w * h (interaction)")
286        else:
287            im = ax.pcolormesh(
288                grid_j, grid_k, surface.T,
289                cmap=cmap, vmin=vmin, vmax=vmax,
290            )
291            plt.colorbar(im, ax=ax, label="w_{jk} * h_{jk}(x)")
292            ax.set_xlabel(feature_j)
293            ax.set_ylabel(feature_k)
294
295        ax.set_title(f"Interaction Surface: {feature_j} x {feature_k}")
296        fig.tight_layout()
297
298        return fig, ax

Plot 2D interaction surface for a feature pair.

Evaluates w_{jk} * h_{jk}(x_j, x_k) over a grid. Points with high absolute values indicate combinations where the interaction adds or subtracts substantially from the linear predictor.

Arguments:
  • feature_j: First feature (x-axis for continuous).
  • feature_k: Second feature (y-axis or color dimension).
  • X_background: Background data for axis range estimation.
  • n_grid: Grid resolution per axis.
  • figsize: Figure size.
  • cmap: Colormap.
  • ax: Existing axes.
Returns:

(fig, ax) tuple.

def plot_training_history(self, figsize: Tuple[int, int] = (8, 4), ax=None):
300    def plot_training_history(
301        self,
302        figsize: Tuple[int, int] = (8, 4),
303        ax=None,
304    ):
305        """
306        Plot training and validation loss curves.
307
308        Args:
309            figsize: Figure size.
310            ax: Existing axes.
311
312        Returns:
313            (fig, ax) tuple.
314        """
315        plt = _get_plt()
316        history = self.model.train_history
317
318        if ax is None:
319            fig, ax = plt.subplots(figsize=figsize)
320        else:
321            fig = ax.get_figure()
322
323        epochs = range(len(history["train_loss"]))
324        ax.plot(epochs, history["train_loss"], label="Train", color="#2196F3")
325        ax.plot(epochs, history["val_loss"], label="Validation", color="#FF5722")
326        ax.set_xlabel("Epoch")
327        ax.set_ylabel("Loss (deviance)")
328        ax.set_title("PIN Training History")
329        ax.legend()
330        ax.grid(alpha=0.3)
331        fig.tight_layout()
332
333        return fig, ax

Plot training and validation loss curves.

Arguments:
  • figsize: Figure size.
  • ax: Existing axes.
Returns:

(fig, ax) tuple.

def summary(self, X_background=None) -> str:
335    def summary(self, X_background=None) -> str:
336        """
337        Print a text summary of the fitted model.
338
339        Args:
340            X_background: Optional background data for pair importance ranking.
341
342        Returns:
343            Summary string.
344        """
345        model = self.model
346        lines = [
347            "=" * 60,
348            "PIN Model Summary",
349            "=" * 60,
350            f"Features:   {model.q}",
351            f"Parameters: {model.count_parameters():,}",
352            f"Loss:       {model.loss_name}",
353            f"Embedding:  d={model.embedding_dim}, d'={model.hidden_dim}, d0={model.token_dim}",
354            f"Shared net: d1={model.shared_dims[0]}, d2={model.shared_dims[1]}",
355            f"n_pairs:    {model.q * (model.q + 1) // 2} "
356            f"({model.q} main + {model.q*(model.q-1)//2} interactions)",
357            "",
358            "Feature list:",
359        ]
360        for name in model.feature_names:
361            spec = model.features[name]
362            ftype = "continuous" if spec == "continuous" else f"categorical ({spec} levels)"
363            lines.append(f"  {name}: {ftype}")
364
365        lines.append("")
366        lines.append("Output weights (|w_{jk}|, top 10):")
367        weights = model.interaction_weights()
368        sorted_w = sorted(weights.items(), key=lambda x: abs(x[1]), reverse=True)[:10]
369        for (fj, fk), w in sorted_w:
370            tag = "(main)" if fj == fk else "(interaction)"
371            lines.append(f"  {fj} x {fk}: {w:+.4f} {tag}")
372
373        if model.train_history["val_loss"]:
374            best_val = min(model.train_history["val_loss"])
375            lines.append(f"\nBest val loss: {best_val:.6f}")
376            lines.append(f"Epochs run:   {len(model.train_history['val_loss'])}")
377
378        lines.append("=" * 60)
379        txt = "\n".join(lines)
380        print(txt)
381        return txt

Print a text summary of the fitted model.

Arguments:
  • X_background: Optional background data for pair importance ranking.
Returns:

Summary string.

def centered_hard_sigmoid(x: torch.Tensor) -> torch.Tensor:
20def centered_hard_sigmoid(x: torch.Tensor) -> torch.Tensor:
21    """
22    The activation used in PIN interaction units.
23
24    Defined in the paper as:
25        sigma_hard(x) = clamp((1 + x) / 2, 0, 1)
26
27    This maps:
28        x = 0  -> 0.5  (the 'centered' sense — zero input gives midpoint)
29        x = -1 -> 0.0
30        x =  1 -> 1.0
31
32    Note: NOT torch.nn.Hardsigmoid, which uses clamp((x+3)/6, 0, 1) — a
33    different shift and scale chosen to approximate the logistic sigmoid
34    (saturating at x = ±3, with Hardsigmoid(0)=0.5). The PIN paper
35    uses the simpler (1+x)/2 formulation.
35
36    Args:
37        x: Input tensor, any shape.
38
39    Returns:
40        Tensor of same shape, values in [0, 1].
41    """
42    return torch.clamp((1.0 + x) / 2.0, min=0.0, max=1.0)

The activation used in PIN interaction units.

Defined in the paper as:

sigma_hard(x) = clamp((1 + x) / 2, 0, 1)

This maps:

x = 0  -> 0.5  (the 'centered' sense — zero input gives midpoint)
x = -1 -> 0.0
x =  1 -> 1.0

Note: NOT torch.nn.Hardsigmoid, which uses clamp((x+3)/6, 0, 1) — a different shift and scale chosen to approximate the logistic sigmoid (saturating at x = ±3, with Hardsigmoid(0)=0.5). The PIN paper uses the simpler (1+x)/2 formulation.

Arguments:
  • x: Input tensor, any shape.
Returns:

Tensor of same shape, values in [0, 1].