Metadata-Version: 2.4
Name: torch-supercluster
Version: 0.1.0
Summary: Supervised prototype-routing for interpretable tabular prediction
Author-email: Aaron John Danielson <aaron.danielson@austin.utexas.edu>
License: MIT
Project-URL: Homepage, https://github.com/aaronjdanielson/torch-supercluster
Project-URL: Bug Tracker, https://github.com/aaronjdanielson/torch-supercluster/issues
Keywords: machine learning,clustering,interpretability,tabular,prototypes,pytorch
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0
Provides-Extra: examples
Requires-Dist: numpy; extra == "examples"
Requires-Dist: scikit-learn; extra == "examples"
Requires-Dist: pandas; extra == "examples"
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: numpy; extra == "dev"
Requires-Dist: scikit-learn; extra == "dev"
Dynamic: license-file

# torch-supercluster

**SuperCluster** learns K prediction prototypes end-to-end under a supervised objective. Each example is routed to the prototypes via cross-attention, and its prediction is an exact convex combination of the prototype-level predictions.

This gives you a model that is simultaneously:
- **Accurate** — matches or exceeds MLP accuracy on tabular benchmarks
- **Interpretable** — each prototype maps to a readable "prediction archetype"
- **Self-analysing** — over-specify K and the model tells you how many distinct prediction regimes the data contains

```
x → MLP encoder → cross-attention (→ prototype centers C) → soft routing π
prediction:  ŷ = σ(π · s)        [binary]
             ŷ = softmax(π @ S)  [multi-class]
             ŷ = π · s           [regression]
```

## Installation

```bash
pip install torch-supercluster
```

Requires Python ≥ 3.9 and PyTorch ≥ 2.0.

## Quick start

### Binary classification

```python
from supercluster import SuperCluster, center_diversity_loss
import torch, torch.nn as nn

model = SuperCluster(
    input_dim=10,
    embedding_dim=64,
    num_clusters=8,       # set larger than expected; surplus prototypes collapse
)

x = torch.randn(256, 10)                          # toy feature batch, [N, 10]
y = torch.randint(0, 2, (256, 1)).float()         # binary targets, [N, 1]

logits, cluster_weights = model(x)                # logits: [N, 1]
loss = nn.BCEWithLogitsLoss()(logits, y)
loss = loss + 0.1 * center_diversity_loss(model.centers)
loss.backward()
```

### Multi-class classification

```python
model = SuperCluster(
    input_dim=54, embedding_dim=128,
    num_clusters=8, num_classes=7,
)
x = torch.randn(256, 54)                         # toy feature batch, [N, 54]
y_long = torch.randint(0, 7, (256,))             # class indices, [N]

logits, cluster_weights = model(x)               # logits: [N, 7]
loss = nn.CrossEntropyLoss()(logits, y_long)
```

### Regression

```python
model = SuperCluster(
    input_dim=10, embedding_dim=64, num_clusters=5,
)
x = torch.randn(256, 10)                         # toy feature batch, [N, 10]
y = torch.randn(256, 1)                          # continuous targets, [N, 1]

preds, cluster_weights = model(x)                # preds: [N, 1]; no sigmoid applied
loss = nn.MSELoss()(preds, y)
```

### Reading prototype predictions

```python
import torch.nn.functional as F

# Binary: prototype predicted probability
for k in range(model.num_clusters):
    prob = torch.sigmoid(model.prototype_scores[k]).item()
    print(f"Prototype {k}: P(y=1) = {prob:.3f}")

# Multi-class: prototype predicted class distribution
for k in range(model.num_clusters):
    probs = F.softmax(model.prototype_scores[k], dim=0)
    print(f"Prototype {k}: {probs.tolist()}")
```

## How it works

SuperCluster separates two concerns that prior prototype models conflate:

| Parameter | Role |
|-----------|------|
| `model.centers` ∈ ℝ^{K×d} | Routing geometry: where prototypes live in embedding space |
| `model.prototype_scores` ∈ ℝ^K or ℝ^{K×C} | Prediction: each prototype's output value |

The center-diversity loss pushes centers apart on the unit hypersphere without
touching prediction scores, preventing the "overconfident prototype" pathology
common when routing geometry and prediction are conflated.
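
For intuition, a diversity penalty of this kind can be sketched as the mean squared pairwise cosine similarity between unit-normalised centers. This is an illustrative stand-in, not necessarily the package's exact `center_diversity_loss`:

```python
import torch

def cosine_diversity_penalty(centers: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in for center_diversity_loss; the packaged
    implementation may differ in detail."""
    c = torch.nn.functional.normalize(centers, dim=-1)   # project rows onto the unit hypersphere
    sim = c @ c.T                                        # pairwise cosine similarities, [K, K]
    k = sim.shape[0]
    off_diag = sim - torch.eye(k, device=sim.device)     # zero out self-similarity
    return off_diag.pow(2).sum() / (k * (k - 1))         # mean squared off-diagonal similarity
```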

### Predictive regime discovery

Set K larger than you expect and examine which prototypes are occupied. The model
concentrates assignment mass on exactly as many prototypes as the data has distinct
prediction behaviours — surplus prototypes collapse to zero routing mass. This gives
an operational measure of **how many prediction regimes the data contains**.

```python
from supercluster import effective_prototype_count

_, cw = model(X_test)                          # X_test: held-out features, [N, input_dim]
k_eff = effective_prototype_count(cw)          # entropy-based count
active = model.active_prototypes(cw)           # list of occupied prototype indices
print(f"K_eff = {k_eff:.2f}, K_active = {len(active)}/{model.num_clusters}")
```
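
For intuition, an entropy-based effective count is commonly the perplexity of the average routing distribution, K_eff = exp(H(π̄)). A hedged sketch, not necessarily how `effective_prototype_count` is defined internally:

```python
import torch

def effective_count_sketch(cluster_weights: torch.Tensor) -> float:
    """Perplexity of the mean routing distribution; illustrative only."""
    mean_pi = cluster_weights.mean(dim=0)                 # average assignment mass, [K]
    entropy = -(mean_pi * (mean_pi + 1e-12).log()).sum()  # Shannon entropy H(mean_pi)
    return entropy.exp().item()                           # exp(H) lies in [1, K]
```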

## Model parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `input_dim` | required | Raw feature dimension |
| `embedding_dim` | required | Latent dimension d |
| `num_clusters` | required | Number of prototypes K |
| `num_classes` | 1 | 1 = binary/regression; C > 1 = multi-class |
| `encoder_layers` | 4 | MLP encoder depth |
| `encoder_hidden_size` | 256 | MLP encoder width |
| `num_attn_heads` | 8 | Cross-attention heads (must divide `embedding_dim`) |
| `num_cross_attn_layers` | 2 | Cross-attention depth L |
| `dropout` | 0.1 | Dropout rate |
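
For reference, a fully specified constructor call using every documented parameter (values illustrative):

```python
model = SuperCluster(
    input_dim=54,              # raw feature dimension
    embedding_dim=128,         # latent dimension d
    num_clusters=8,            # number of prototypes K
    num_classes=7,             # 1 = binary/regression; C > 1 = multi-class
    encoder_layers=4,          # MLP encoder depth
    encoder_hidden_size=256,   # MLP encoder width
    num_attn_heads=8,          # must divide embedding_dim (128 / 8 = 16)
    num_cross_attn_layers=2,   # cross-attention depth L
    dropout=0.1,
)
```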

## Training recommendations

| Setting | Recommendation |
|---------|----------------|
| Optimizer | AdamW or Adam, lr = 1e-3 |
| Diversity weight | `λ_div = 0.1` — add `λ_div * center_diversity_loss(model.centers)` to main loss |
| K | 1.5–3× your expected regime count; performance is flat across K |
| Patience | Early-stop on validation loss with patience 50–100 epochs |
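
Putting these together, a minimal training-loop sketch for a binary task (assumes `model` from the quick start and pre-split tensors `x_train`, `y_train`, `x_val`, `y_val`):

```python
import torch
import torch.nn as nn
from supercluster import center_diversity_loss

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()
best_val, patience, bad_epochs = float("inf"), 100, 0

for epoch in range(2000):
    model.train()
    logits, _ = model(x_train)
    loss = criterion(logits, y_train) + 0.1 * center_diversity_loss(model.centers)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_logits, _ = model(x_val)
        val_loss = criterion(val_logits, y_val).item()

    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0   # improvement: reset the patience counter
    else:
        bad_epochs += 1
        if bad_epochs >= patience:           # stop after `patience` stagnant epochs
            break
```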

## Empirical results

Test accuracy (ROC AUC where reported); K_active is the number of occupied prototypes out of K = 8.

| Dataset | MLP | SuperCluster | K_active |
|---------|-----|--------------|----------|
| NBA shots (binary, 72k) | 63.3% | 63.4% | 3/8 |
| Bank Marketing (binary, 45k) | 88.9% / 78.7% AUC | 89.0% / 77.8% AUC | 3/8 |
| Adult Income (binary, 45k) | 85.0% / 91.0% AUC | 85.2% / 90.7% AUC | 2/8 |
| Credit Default (binary, 30k) | 81.7% / 76.8% AUC | 82.0% / 76.8% AUC | 2/8 |
| Covertype (7-class, 581k) | 95.7% | 95.5% | 6/8 |

## Citation

If you use this package in academic work, please cite:

```bibtex
@article{danielson2026supercluster,
  title   = {SuperCluster: Learning Prediction Prototypes via
             Target-Guided Cross-Attention Clustering},
  author  = {Danielson, Aaron John},
  journal = {Data Mining and Knowledge Discovery},
  year    = {2026},
}
```

## License

MIT
