Metadata-Version: 2.4
Name: probekit
Version: 0.2.2
Summary: A versatile kit for training and using linear probes on neural network activations.
Author: Probekit Contributors
License: MIT
Project-URL: Homepage, https://github.com/ZuiderveldTimJ/probekit
Project-URL: Repository, https://github.com/ZuiderveldTimJ/probekit
Project-URL: Issues, https://github.com/ZuiderveldTimJ/probekit/issues
Keywords: interpretability,safety,llm,probes
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: scikit-learn>=1.3.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: tqdm>=4.65.0
Requires-Dist: sae-lens>=3.0.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: ruff==0.1.6; extra == "dev"
Requires-Dist: mypy>=1.0.0; extra == "dev"
Requires-Dist: build; extra == "dev"
Requires-Dist: twine; extra == "dev"
Requires-Dist: pre-commit; extra == "dev"
Dynamic: license-file

# Probekit

A lightweight, modular library for training linear probes and steering vectors on neural network activations.

## Core Design (V2)

This library separates **Semantics** (the probe model) from **Fitting** (how it's learned).

### 1. The Models: `LinearProbe` and `ProbeCollection`
- **`LinearProbe`** (`probekit.core.probe`): A container for a single probe (+ normalization stats).
- **`ProbeCollection`** (`probekit.core.collection`): A container for a **batch** of probes.
    - `to_tensor()`: Stacks weights into `[B, D]` and biases into `[B]`.
    - `best_layer(metric)`: Returns the probe that scores best on the given validation metric.
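The semantics of a linear probe are simple enough to sketch directly: normalize the activation with stored statistics, project onto a learned direction, add a bias, and squash. The field names below (`w`, `b`, `mean`, `std`) are illustrative and not necessarily `LinearProbe`'s actual attributes:

```python
import numpy as np

class SketchLinearProbe:
    """Illustrative stand-in for LinearProbe: direction + bias + norm stats."""

    def __init__(self, w, b, mean, std):
        self.w, self.b = np.asarray(w, dtype=float), float(b)
        self.mean, self.std = np.asarray(mean, dtype=float), np.asarray(std, dtype=float)

    def predict_proba(self, x):
        z = (np.asarray(x, dtype=float) - self.mean) / self.std  # apply stored normalization
        logit = z @ self.w + self.b                               # linear projection
        return 1.0 / (1.0 + np.exp(-logit))                       # sigmoid -> P(class 1)

probe = SketchLinearProbe(w=[2.0, -1.0], b=0.0, mean=[0.0, 0.0], std=[1.0, 1.0])
p = probe.predict_proba([1.0, 1.0])  # logit = 2 - 1 = 1, so p = sigmoid(1)
```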

### 2. The Fitters
Functional solvers in `probekit.fitters` take training data and return a `LinearProbe` (or `ProbeCollection`).

- `fit_logistic`: Standard L2-regularized Logistic Regression.
- `fit_elastic_net`: ElasticNet (L1 + L2), useful for sparse features (SAE latents, neurons).
- `fit_dim`: Difference-in-Means (Class 1 Mean - Class 0 Mean).
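Difference-in-Means is the simplest of the three fitters, so its math fits in a few lines. This is a hedged sketch of the idea, not `fit_dim`'s actual signature:

```python
import numpy as np

def dim_direction(X, y):
    """Probe direction = mean of class-1 activations minus mean of class-0 activations."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    return X[y == 1].mean(axis=0) - X[y == 0].mean(axis=0)

X = np.array([[0.0, 0.0],
              [0.0, 2.0],
              [4.0, 0.0],
              [4.0, 2.0]])
y = np.array([0, 0, 1, 1])
w = dim_direction(X, y)  # class means [4, 1] - [0, 1] -> array([4., 0.])
```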

#### Batched GPU Fitters
Optimized PyTorch implementations in `probekit.fitters.batch` handle 3D inputs `[B, N, D]` efficiently on GPU:
- `fit_logistic_batch`: Batched IRLS solver.
- `fit_dim_batch`: Vectorized DiM with median thresholding.
- `fit_elastic_net_path`: Efficiently fits a regularization path (multiple alphas) using warm-starting.
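The batched DiM idea vectorizes cleanly: compute all `B` directions from `[B, N, D]` activations at once, then set each bias so the decision threshold sits at the median projection. This is a hypothetical NumPy sketch of that pattern, not the real `fit_dim_batch` implementation (which runs in PyTorch on GPU):

```python
import numpy as np

def dim_batch(X, y):
    X = np.asarray(X, dtype=float)        # [B, N, D] activations
    y = np.asarray(y, dtype=bool)         # [N] labels, shared across the batch
    mu1 = X[:, y].mean(axis=1)            # [B, D] class-1 means
    mu0 = X[:, ~y].mean(axis=1)           # [B, D] class-0 means
    W = mu1 - mu0                         # [B, D] all DiM directions at once
    proj = np.einsum("bnd,bd->bn", X, W)  # [B, N] projections onto each direction
    b = -np.median(proj, axis=1)          # [B] median thresholding as the bias
    return W, b

X = np.random.default_rng(0).normal(size=(3, 8, 4))
y = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=bool)
W, b = dim_batch(X, y)  # W: [3, 4], b: [3]
```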

## Quick Start

The high-level API automatically routes based on the input dimensions:

```python
from probekit import sae_probe, dim_probe

# 1. Single Probe (X: [N, D], y: [N])
probe = sae_probe(X_2d, y_1d)

# 2. Batched Probes (X: [B, N, D], y: [B, N] or [N])
# Automatically uses GPU fitters and returns a ProbeCollection
probes = sae_probe(X_3d, y)
weights, biases = probes.to_tensor() # [B, D], [B]
```
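The routing rule itself is just a dimensionality check. A minimal sketch of the dispatch the Quick Start describes (the real `sae_probe` may route differently; the return strings here are placeholders for the two code paths):

```python
import numpy as np

def route(X):
    """2D activations -> single-probe path; 3D -> batched GPU path."""
    ndim = np.asarray(X).ndim
    if ndim == 2:
        return "single"   # X: [N, D] -> one LinearProbe
    if ndim == 3:
        return "batched"  # X: [B, N, D] -> a ProbeCollection
    raise ValueError(f"expected 2D or 3D activations, got {ndim}D")

assert route(np.zeros((16, 8))) == "single"
assert route(np.zeros((4, 16, 8))) == "batched"
```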

## Steering Vectors

You can build steering vectors for individual probes or entire collections:

```python
from probekit import build_steering_vector, build_steering_vectors

# Single
vec = build_steering_vector(probe, sae_model, layer=10)

# Batched (Maps layers to probes)
vecs = build_steering_vectors(probe_collection, sae_model, layers=[8, 9, 10])
```
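One common way to turn an SAE-latent probe into an activation-space steering vector is to map its weights back through the SAE decoder. The sketch below assumes that construction and a `[F, D]` decoder matrix; `build_steering_vector` may work differently:

```python
import numpy as np

def steering_from_sae_probe(w_latent, W_dec):
    """w_latent: [F] probe weights in SAE latent space.
    W_dec: [F, D] decoder matrix mapping latents back to model activations."""
    v = w_latent @ W_dec          # [D] direction in activation space
    return v / np.linalg.norm(v)  # unit-normalize; caller picks the steering scale

w = np.array([1.0, 0.0, 0.0])
W_dec = np.eye(3)                 # identity decoder for illustration
v = steering_from_sae_probe(w, W_dec)
```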

## Structure

- `probekit/core/`: `LinearProbe` and `ProbeCollection` definitions.
- `probekit/fitters/`:
    - `logistic.py`, `elastic.py`, `dim.py`: Single-probe (CPU/sklearn) fitters.
    - `batch/`: Optimized GPU-batched fitters (IRLS, ISTA, DiM).
- `probekit/api.py`: High-level aliases and dimension routing.
- `probekit/steering/`: Tools for building steering vectors.
