Metadata-Version: 2.4
Name: subset-mixture-model
Version: 0.1.2
Summary: Interpretable empirical-Bayes aggregation of partition estimators for categorical regression
Author-email: Aaron John Danielson <aaron.danielson@austin.utexas.edu>
License: MIT
Project-URL: Repository, https://github.com/aaronjdanielson/subset-mixture-model
Keywords: interpretable machine learning,categorical features,empirical Bayes,uncertainty quantification,mixture model
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.0
Requires-Dist: numpy>=1.24
Requires-Dist: pandas>=2.0
Requires-Dist: scikit-learn>=1.3
Requires-Dist: scipy>=1.11
Provides-Extra: experiments
Requires-Dist: matplotlib>=3.7; extra == "experiments"
Requires-Dist: seaborn>=0.12; extra == "experiments"
Requires-Dist: joblib>=1.3; extra == "experiments"
Requires-Dist: lightgbm>=4.0; extra == "experiments"
Requires-Dist: ngboost>=0.4; extra == "experiments"
Requires-Dist: mapie>=0.6; extra == "experiments"
Provides-Extra: dev
Requires-Dist: pytest>=7; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"

# Subset Mixture Model (SMM)

[![PyPI version](https://badge.fury.io/py/subset-mixture-model.svg)](https://pypi.org/project/subset-mixture-model/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

**SMM** is an interpretable, uncertainty-aware regression method for datasets with categorical features. It learns a weighted average of partition estimators—one per non-empty feature subset—and tells you exactly *why* each prediction has the value and uncertainty it does.

---

## Key idea

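Each feature subset *s* groups the training data by the unique value combinations of its features and stores the empirical mean and variance of the target per group (cell). With *D* categorical features there are 2^D − 1 non-empty subsets. SMM learns a single global weight vector π over all of them by minimizing a penalized negative log-likelihood (the NLL of the resulting Gaussian mixture plus a Dirichlet prior on π). The point prediction is: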

$$\hat{f}(\mathbf{x}) = \sum_{s \in \mathcal{S}} \hat{\pi}_s \cdot \hat{\mu}_{m(s,\mathbf{x})}(s)$$

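Here *m(s, x)* is the cell of subset *s* that *x* falls into, and μ̂ is that cell's empirical mean. The learned weights directly answer: *which feature combinations matter?* Predictions are convex combinations of verifiable training-data statistics, not the output of a black box.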
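The formula can be reproduced by hand on a toy dataset. The sketch below uses only pandas and numpy, not the `smm` API; the data, the weights in `pi`, and the `predict` helper are all hypothetical, and the weights are fixed by hand rather than learned.

```python
from itertools import combinations

import numpy as np
import pandas as pd

# Toy training data: two categorical features and a numeric target.
train = pd.DataFrame({
    "region": [0, 0, 1, 1, 1, 0],
    "tier":   [0, 1, 0, 1, 1, 0],
    "y":      [1.0, 2.0, 5.0, 7.0, 9.0, 3.0],
})
features = ["region", "tier"]

# Enumerate all 2^D - 1 non-empty feature subsets (D = 2 gives 3 subsets).
subsets = [
    combo
    for k in range(1, len(features) + 1)
    for combo in combinations(features, k)
]

# One partition estimator per subset: the per-cell empirical mean of y.
cell_means = {s: train.groupby(list(s))["y"].mean() for s in subsets}

# Hypothetical mixture weights on the simplex (SMM would learn these).
pi = {("region",): 0.5, ("tier",): 0.2, ("region", "tier"): 0.3}
assert np.isclose(sum(pi.values()), 1.0)


def predict(x):
    """Convex combination of the matching cell mean from each subset."""
    total = 0.0
    for s in subsets:
        key = tuple(x[f] for f in s)
        key = key[0] if len(key) == 1 else key  # scalar key for single-feature subsets
        total += pi[s] * cell_means[s].loc[key]
    return float(total)


x = {"region": 1, "tier": 1}
prediction = predict(x)
print(prediction)  # 0.5 * 7.0 + 0.2 * 6.0 + 0.3 * 8.0 = 7.1
```

Each term in the sum is a group mean that can be checked directly against `train`, which is the sense in which the prediction is verifiable.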

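**Oracle guarantee:** the log-loss of the learned mixture is within log(|S|)/n of the log-loss of the best single-subset estimator, where |S| = 2^D − 1 is the number of subsets and n is the number of training examples.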

---

## Installation

```bash
pip install subset-mixture-model
```

Import as `smm`:

```python
import smm
```

---

## Complete worked example

The example below is fully self-contained. It creates a synthetic dataset with known structure, trains SMM, makes calibrated predictions with uncertainty estimates, then uses the diagnostic tools to trace *why* each prediction looks the way it does.

```python
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from smm import (
    SubsetMaker, SubsetWeightsModel, SubsetDataset,
    subset_mixture_neg_log_posterior,
    SubsetMixturePredictor,
    compute_posterior_covariance,
    predict_with_uncertainty,
    coverage,
    weight_table,
    explain_prediction,
    calibration_stats,
)

# ── 1. Synthetic data ──────────────────────────────────────────────────────────
#
# Three categorical features: region (2 values), season (3 values), tier (2 values).
# True signal: region drives a baseline; season × tier creates an interaction.
# SMM should discover both without being told which interactions to look for.

rng = np.random.default_rng(42)
N = 2000

region = rng.integers(0, 2, N)
season = rng.integers(0, 3, N)
tier   = rng.integers(0, 2, N)

y = 5.0 * region + 3.0 * (season == 1) * tier + rng.normal(0, 2.0, N)

df = pd.DataFrame({"region": region, "season": season, "tier": tier, "y": y})

idx = rng.permutation(N)
n_train, n_val = int(0.70 * N), int(0.15 * N)
train_df = df.iloc[idx[:n_train]].reset_index(drop=True)
val_df   = df.iloc[idx[n_train:n_train + n_val]].reset_index(drop=True)
test_df  = df.iloc[idx[n_train + n_val:]].reset_index(drop=True)

CAT_COLS = ["region", "season", "tier"]
TARGET   = "y"

# ── 2. Build the lookup table ──────────────────────────────────────────────────
#
# SubsetMaker enumerates all 2^3 − 1 = 7 non-empty feature subsets, groups the
# training data by value combinations within each subset, and stores the
# empirical (mean, variance) of the target per group.

subset_maker = SubsetMaker(train_df, CAT_COLS, [TARGET])
n_subsets = len(subset_maker.lookup)
print(f"Subsets: {n_subsets}")   # → 7

# ── 3. Train the weight model ──────────────────────────────────────────────────
#
# SubsetWeightsModel holds a single logit vector η ∈ R^|S|.
# The training loss is negative log-posterior: NLL of the Gaussian mixture
# plus a Dirichlet prior (alpha > 1 discourages degenerate weights).

model     = SubsetWeightsModel(n_subsets)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)
ALPHA     = 1.1

train_loader = DataLoader(
    SubsetDataset(train_df, CAT_COLS, [TARGET]), batch_size=256, shuffle=True
)
val_loader = DataLoader(
    SubsetDataset(val_df, CAT_COLS, [TARGET]), batch_size=256, shuffle=False
)

best_val, no_improve, best_state = float("inf"), 0, None

for epoch in range(300):
    model.train()
    for x, y_batch in train_loader:
        optimizer.zero_grad()
        mus, variances, mask = subset_maker.batch_lookup(x)
        subset_mixture_neg_log_posterior(
            model(), y_batch, mus, variances, mask, alpha=ALPHA
        ).backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y_batch in val_loader:
            mus, variances, mask = subset_maker.batch_lookup(x)
            val_loss += subset_mixture_neg_log_posterior(
                model(), y_batch, mus, variances, mask, alpha=ALPHA
            ).item()
    val_loss /= len(val_loader)

    if val_loss < best_val:
        best_val, no_improve = val_loss, 0
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    else:
        no_improve += 1
    if no_improve >= 20 and epoch >= 50:
        break

model.load_state_dict(best_state)

# ── 4. Predictor and posterior covariance ──────────────────────────────────────
#
# The Laplace approximation treats the MAP estimate η̂ as the center of a
# Gaussian, computes the Hessian of the loss there, and propagates uncertainty
# to the simplex via the softmax Jacobian: Σ_π = J H⁻¹ Jᵀ.

pi_hat    = F.softmax(model.eta.detach(), dim=0)
predictor = SubsetMixturePredictor(subset_maker, pi_hat)
sigma_pi  = compute_posterior_covariance(
    model, subset_maker, train_df, CAT_COLS, TARGET, alpha=ALPHA
)

# ── 5. Predict ─────────────────────────────────────────────────────────────────

y_mean, y_std, aleatoric_std, epistemic_std = predict_with_uncertainty(
    predictor, sigma_pi, test_df, return_components=True
)
y_true = test_df[TARGET].values

print(f"Test RMSE:     {np.sqrt(np.mean((y_mean - y_true)**2)):.3f}")
print(f"95% coverage:  {coverage(y_true, y_mean, y_std, level=0.95):.3f}")

# ── 6. Diagnostic: which subsets drive predictions? ───────────────────────────
#
# weight_table() returns a DataFrame sorted by π_s descending.
# Because the true signal has a "region" main effect and a "season × tier"
# interaction, those two subsets should dominate.

wt = weight_table(subset_maker, pi_hat, top_k=5)
print("\nTop-5 subsets by weight:")
print(wt[["subset", "weight", "n_cells", "cumulative_weight"]].to_string(index=False))
# Expected output (approximately):
#            subset  weight  n_cells  cumulative_weight
#         (region,)    0.41        2               0.41
#    (season, tier)    0.33        6               0.74
#  (region, season)    0.11      ...               0.85

# ── 7. Diagnostic: why this prediction for one test point? ────────────────────
#
# explain_prediction() shows every subset that has a valid training cell for
# this point, along with its cell statistics, renormalized weight, and additive
# contribution to the final prediction.

row = test_df.iloc[[0]]
exp = explain_prediction(predictor, row)

print(f"\nPrediction: {exp.attrs['predicted_mean']:.3f}  "
      f"(true: {y_true[0]:.3f})")
print(f"Uncertainty: total={y_std[0]:.3f}  "
      f"aleatoric={aleatoric_std[0]:.3f}  epistemic={epistemic_std[0]:.3f}")
print("\nPer-subset breakdown:")
print(exp[["subset", "cell_mean", "masked_weight", "contribution"]].to_string(index=False))
# Each row is a training-data statistic you can verify from train_df directly.
# The prediction equals the sum of the "contribution" column.

# ── 8. Diagnostic: are the intervals well-calibrated? ────────────────────────

cal = calibration_stats(y_true, y_mean, y_std)
print("\nCalibration:")
print(cal.to_string(index=False))
# Values near nominal → well-calibrated.
# SMM is typically slightly conservative (empirical ≥ nominal).
```
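Step 4 of the example propagates uncertainty from the logits to the simplex with Σ_π = J H⁻¹ Jᵀ. The sketch below illustrates that formula in plain numpy. It is not the implementation of `compute_posterior_covariance`; the function `simplex_covariance`, the logits, and the Hessian are hypothetical stand-ins chosen for illustration.

```python
import numpy as np


def softmax(eta):
    """Numerically stable softmax of a 1-D logit vector."""
    z = np.exp(eta - eta.max())
    return z / z.sum()


def simplex_covariance(eta_hat, hessian):
    """Delta-method covariance of pi = softmax(eta) around eta_hat."""
    pi = softmax(eta_hat)
    # Softmax Jacobian: d pi_i / d eta_j = pi_i * (delta_ij - pi_j).
    jac = np.diag(pi) - np.outer(pi, pi)
    # Laplace approximation: the logit covariance is the inverse Hessian.
    cov_eta = np.linalg.inv(hessian)
    return jac @ cov_eta @ jac.T


# Hypothetical MAP logits and a hypothetical positive-definite Hessian.
eta_hat = np.array([1.0, 0.5, -0.5])
hessian = np.array([
    [4.0, 1.0, 0.0],
    [1.0, 3.0, 0.5],
    [0.0, 0.5, 2.0],
])

sigma_pi_demo = simplex_covariance(eta_hat, hessian)
print(np.round(sigma_pi_demo, 4))
```

Because the weights always sum to one, every row of the resulting covariance sums to zero: uncertainty can only move mass between subsets, never change the total.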

---

## Understanding the diagnostics

### Weight table

```
        subset  weight  n_cells  cumulative_weight
     (region,)    0.41        2               0.41
(season, tier)    0.33        6               0.74
           ...
```

A high weight on `(region,)` means knowing the region alone explains a large share of variance. A high weight on `(season, tier)` means the season–tier interaction is informative *beyond* either feature alone. Subsets with negligible weight contribute little—the model has automatically selected which interactions matter.

### Prediction breakdown

```
          subset  cell_mean  masked_weight  contribution
       (region,)        8.1           0.52          4.21
  (season, tier)        6.8           0.36          2.45
(region, season)        7.9           0.09          0.71
```

Every row traces back to a specific group of training examples. The predicted value equals the sum of the `contribution` column. This is not a black box.
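The arithmetic behind the breakdown can be checked directly. The sketch below copies the illustrative rows from the table above and confirms that each contribution is `cell_mean × masked_weight`. The second half shows one plausible way a renormalized ("masked") weight could be derived from global weights; that rule, and the arrays `pi` and `valid`, are assumptions made for illustration rather than something taken from the library's source.

```python
import numpy as np
import pandas as pd

# Rows copied from the illustrative breakdown above.
breakdown = pd.DataFrame({
    "subset": ["(region,)", "(season, tier)", "(region, season)"],
    "cell_mean": [8.1, 6.8, 7.9],
    "masked_weight": [0.52, 0.36, 0.09],
    "contribution": [4.21, 2.45, 0.71],
})

# Each contribution should equal cell_mean * masked_weight (up to rounding).
recomputed = breakdown["cell_mean"] * breakdown["masked_weight"]
max_gap = float((recomputed - breakdown["contribution"]).abs().max())
print(f"largest rounding gap: {max_gap:.4f}")

# Assumed renormalization rule: zero out subsets with no matching training
# cell for this point, then rescale the remaining weights to sum to one.
pi = np.array([0.4, 0.3, 0.2, 0.1])          # hypothetical global weights
valid = np.array([True, True, False, True])  # hypothetical validity mask
masked = np.where(valid, pi, 0.0)
masked_weight = masked / masked.sum()
print(np.round(masked_weight, 3))
```

The three rows shown carry 0.97 of the masked weight, so their contributions sum to slightly less than the full prediction; the remainder comes from the subsets not listed.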

### Uncertainty decomposition

| Component | Source | Grows when… |
|---|---|---|
| **Aleatoric** | Variance within matched training cells | Training cell has high spread |
| **Epistemic** | Laplace posterior over mixture weights | Data is sparse; fewer subsets are active |
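
One way to picture the decomposition for a single test point is sketched below. It assumes the aleatoric part is the variance of the Gaussian mixture over matched cells, the epistemic part is the delta-method variance of the mean under the weight covariance, and the two add in variance. These formulas are assumptions for illustration, not a transcription of `predict_with_uncertainty`, and all numbers are hypothetical.

```python
import numpy as np

# Hypothetical quantities for one test point with three active subsets.
pi = np.array([0.5, 0.3, 0.2])   # renormalized mixture weights
mu = np.array([8.0, 7.0, 6.0])   # matched cell means
var = np.array([4.0, 3.0, 5.0])  # matched cell variances
weight_cov = np.array([          # hypothetical covariance of the weights
    [0.02, -0.01, -0.01],
    [-0.01, 0.02, -0.01],
    [-0.01, -0.01, 0.02],
])

mean = pi @ mu

# Aleatoric (assumed): variance of the mixture itself,
# E[var] + E[mu^2] - (E[mu])^2 under the weights pi.
aleatoric_var = pi @ (var + mu**2) - mean**2

# Epistemic (assumed): the mean is linear in pi, so its variance under
# weight uncertainty is mu^T Sigma_pi mu.
epistemic_var = mu @ weight_cov @ mu

total_std = np.sqrt(aleatoric_var + epistemic_var)
print(f"mean={mean:.2f}  aleatoric_std={np.sqrt(aleatoric_var):.3f}  "
      f"epistemic_std={np.sqrt(epistemic_var):.3f}  total_std={total_std:.3f}")
```

In this toy case the aleatoric term dominates, which matches the table: within-cell spread sets the floor, and weight uncertainty adds to it when the weights are poorly determined.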

### Calibration table

```
 nominal  empirical
    0.50       0.51
    0.80       0.82
    0.95       0.96
```

When empirical coverage is close to nominal coverage at every level, the stated intervals contain the true value at the stated rate. Plotted as a reliability diagram with `plot_calibration()` (requires matplotlib), such a model lies close to the diagonal.
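
A table like this can be produced from predictive means and standard deviations alone. The sketch below assumes central Gaussian intervals (mean ± z·std) and uses synthetic, perfectly calibrated data. The helper `empirical_coverage` is hypothetical and may differ from how `coverage` and `calibration_stats` build their intervals.

```python
from statistics import NormalDist

import numpy as np


def empirical_coverage(y_true, y_mean, y_std, level):
    """Fraction of targets inside the central Gaussian interval at `level`."""
    # Two-sided z-value, e.g. about 1.96 for level = 0.95.
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    inside = np.abs(y_true - y_mean) <= z * y_std
    return float(inside.mean())


# Synthetic check: targets really are Gaussian around the predicted mean
# with the predicted std, so coverage should track the nominal level.
rng = np.random.default_rng(0)
n = 20_000
y_mean_demo = rng.normal(0.0, 1.0, n)
y_std_demo = np.full(n, 2.0)
y_true_demo = y_mean_demo + rng.normal(0.0, 2.0, n)

levels = [0.50, 0.80, 0.95]
table = [
    (lvl, empirical_coverage(y_true_demo, y_mean_demo, y_std_demo, lvl))
    for lvl in levels
]
print(" nominal  empirical")
for nominal, empirical in table:
    print(f"    {nominal:.2f}       {empirical:.3f}")
```

Empirical values above nominal indicate intervals that are wider than necessary (conservative); values below nominal indicate overconfident intervals.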

---

## API reference

### Data and training

| Symbol | Description |
|---|---|
| `SubsetMaker(df, cat_cols, [target])` | Build powerset lookup table from training data |
| `SubsetWeightsModel(n_subsets)` | Trainable logit parameter η of length \|S\| |
| `SubsetDataset(df, cat_cols, [target])` | PyTorch Dataset wrapping a DataFrame |
| `subset_mixture_neg_log_posterior(logits, y, mus, vars, mask, alpha)` | Training loss (NLL + Dirichlet prior) |
| `subset_mixture_mse(logits, y, mus, mask)` | MSE loss for warmup or debugging |
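
To make the training loss concrete, here is a numpy sketch of a masked Gaussian-mixture NLL with a Dirichlet log-prior. It mirrors the argument names in the table but is not the library's implementation: the function `neg_log_posterior`, the per-example renormalization over valid subsets, and the scaling of the prior by batch size are all assumptions made for illustration.

```python
import numpy as np


def neg_log_posterior(eta, y, mus, variances, mask, alpha=1.1):
    """Masked Gaussian-mixture NLL plus a Dirichlet log-prior (sketch).

    eta: logits [S]; y: targets [B]; mus, variances, mask: [B, S].
    Every row of `mask` is assumed to contain at least one True entry.
    """
    # Global log-weights via a stable log-softmax.
    log_pi = eta - eta.max()
    log_pi = log_pi - np.log(np.exp(log_pi).sum())

    # Per-example weights: drop invalid subsets, renormalize the rest.
    masked_log_pi = np.where(mask, log_pi, -np.inf)
    row_max = masked_log_pi.max(axis=1, keepdims=True)
    log_norm = row_max + np.log(
        np.exp(masked_log_pi - row_max).sum(axis=1, keepdims=True)
    )
    log_w = masked_log_pi - log_norm

    # Gaussian log-density under each subset's cell statistics; masked
    # entries get placeholder values so NaNs cannot leak into the sum.
    safe_mu = np.where(mask, mus, 0.0)
    safe_var = np.where(mask, variances, 1.0)
    log_pdf = -0.5 * (
        np.log(2.0 * np.pi * safe_var) + (y[:, None] - safe_mu) ** 2 / safe_var
    )

    # Log-sum-exp over subsets gives the mixture log-likelihood per example.
    joint = log_w + log_pdf
    m = joint.max(axis=1, keepdims=True)
    log_lik = (m + np.log(np.exp(joint - m).sum(axis=1, keepdims=True))).ravel()

    # Dirichlet(alpha) log-prior on the global weights, up to a constant.
    log_prior = (alpha - 1.0) * log_pi.sum()
    return float(-log_lik.mean() - log_prior / len(y))


# Hypothetical batch: two examples, two subsets, one missing cell.
eta = np.zeros(2)
y = np.array([0.0, 1.0])
mus = np.array([[0.0, np.nan], [1.0, 0.5]])
variances = np.array([[1.0, np.nan], [1.0, 2.0]])
mask = np.array([[True, False], [True, True]])

loss = neg_log_posterior(eta, y, mus, variances, mask, alpha=1.1)
print(round(loss, 4))
```

With `alpha > 1` the prior term grows as any weight approaches zero, which is one way to read the remark that it discourages degenerate weights.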

### Inference

| Symbol | Description |
|---|---|
| `SubsetMixturePredictor(subset_maker, pi_hat)` | Inference wrapper |
| `predictor.predict(df, return_debug=False)` | Point predictions; `return_debug=True` also returns per-example weight matrix `[B, \|S\|]` and fallback mask |

### Uncertainty

| Symbol | Description |
|---|---|
| `compute_posterior_covariance(model, subset_maker, train_df, cat_cols, target, alpha)` | Laplace approximation → Σ_π `[\|S\|, \|S\|]` |
| `predict_with_uncertainty(predictor, sigma_pi, df, return_components=False)` | Mean + total std; optionally aleatoric and epistemic stds |
| `coverage(y_true, y_mean, y_std, level=0.95)` | Empirical interval coverage at given level |

### Diagnostics

| Symbol | Description |
|---|---|
| `weight_table(subset_maker, pi_hat, top_k=None)` | DataFrame of subsets ranked by learned weight |
| `explain_prediction(predictor, row_df)` | Per-subset contribution breakdown for one test point |
| `calibration_stats(y_true, y_mean, y_std, levels=None)` | Empirical vs. nominal coverage at multiple levels |
| `plot_calibration(y_true, y_mean, y_std, levels=None, ax=None)` | Reliability diagram (requires matplotlib) |

---

## Requirements

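**Core:** Python 3.9 or newer, `torch >= 2.0`, `numpy >= 1.24`, `pandas >= 2.0`, `scikit-learn >= 1.3`, `scipy >= 1.11`

`plot_calibration` additionally requires `matplotlib`, which is included in the optional `experiments` extra (`pip install "subset-mixture-model[experiments]"`).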

---

## Citation

```bibtex
@article{danielson2025smm,
  title   = {Subset Mixture Model: Interpretable Aggregation of Partition Estimators},
  author  = {Danielson, Aaron John},
  journal = {Transactions on Machine Learning Research},
  year    = {2025},
}
```

---

## License

MIT © Aaron John Danielson
