Metadata-Version: 2.4
Name: rl-interrogate
Version: 0.1.0
Summary: Mechanistic interpretability toolkit for RL policies
Project-URL: Homepage, https://github.com/aroransh/rl_interrogate
Project-URL: Repository, https://github.com/aroransh/rl_interrogate
Project-URL: Documentation, https://github.com/aroransh/rl_interrogate#readme
Author: Ansh Arora
License-Expression: MIT
License-File: LICENSE
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Requires-Dist: gymnasium
Requires-Dist: matplotlib
Requires-Dist: numpy
Requires-Dist: scikit-learn
Requires-Dist: seaborn
Requires-Dist: stable-baselines3
Requires-Dist: torch>=2.0
Provides-Extra: dev
Requires-Dist: hypothesis; extra == 'dev'
Requires-Dist: pytest; extra == 'dev'
Provides-Extra: mujoco
Requires-Dist: mujoco; extra == 'mujoco'
Description-Content-Type: text/markdown

# rl_interrogate

Mechanistic interpretability toolkit for RL policies. Probe, ablate, and interrogate
what your policy has learned — not just how well it performs.

## Installation

From the root of a clone of the repository:

```bash
pip install -e .
```

Requires Python >= 3.9. Dependencies: `torch` (>= 2.0), `numpy`, `scikit-learn`, `matplotlib`,
`seaborn`, `gymnasium`, `stable-baselines3`. MuJoCo environments require the `mujoco` extra:

```bash
pip install -e ".[mujoco]"
```

## Minimal Example

Load a checkpoint, run a probe, run ablation:

```python
import torch
import numpy as np
from rl_interrogate import LinearProbe, AblationHook

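# 1. Load your policy network (any torch.nn.Sequential saved whole with torch.save)
#    weights_only=False is required to unpickle a full module on torch>=2.6;
#    only do this for checkpoints you trust
policy_net = torch.load("my_policy.pt", weights_only=False)
policy_net.eval()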

# 2. Build a synthetic observation dataset (500 samples of 28-dim observations)
obs_grid = np.random.randn(500, 28).astype(np.float32)
labels = obs_grid[:, 10]  # probe target: observation feature 10 (lateral position)

# 3. Linear probe at layer 5
probe = LinearProbe()
probe.fit(policy_net, layer_idx=5, obs_dataset=obs_grid, labels=labels)
print(f"Layer 5 R² = {probe.score():.4f}")

# 4. Ablation: zero the probe direction, then measure the change in episode return
#    (the probe's fitted weight vector serves as the direction)
hook = AblationHook(policy_net, layer_idx=5, direction=probe._probe.coef_)
with hook.apply(alpha=0.0):
    # run environment rollouts here; inside this block the probe direction is zeroed
    pass
```

## Experiments

The library was developed for the WakeRider paper (TMLR submission). Key experiments:

- **Formation flight probe** (`examples/formation_flight_probe.py`): Reproduces the
  actor layer-5 probe result (R² = 0.973) from the seed-42 checkpoint.
- **HalfCheetah ablation** (`examples/halfcheetah_ablation.py`): Runs ablation on
  a HalfCheetah-v4 policy, showing that ablating PC1 degrades performance by ~10%.

## API Reference

### Probing

```python
from rl_interrogate import LinearProbe, MLPProbe, LassoProbe

# Ridge regression probe (recommended)
probe = LinearProbe()
probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
r2 = probe.score()

# MLP probe (non-linear)
mlp_probe = MLPProbe()
mlp_probe.fit(model, layer_idx=5, obs_dataset=obs, labels=y)

# Sparse Lasso probe
lasso = LassoProbe()
lasso.fit(model, layer_idx=5, obs_dataset=obs, labels=y)
r2, n_nonzero = lasso.score()
```
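For intuition, the probing workflow above can be approximated with a PyTorch forward hook plus scikit-learn regressors. This is a minimal sketch under stated assumptions, not the library's implementation: it assumes the model is an `nn.Sequential` indexed by `layer_idx`, and it scores R² on a held-out 20% split, which may differ from how `probe.score()` computes its value. `collect_activations` and `probe_r2` are hypothetical helper names.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.linear_model import Lasso, Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split


def collect_activations(model: nn.Sequential, layer_idx: int, obs: np.ndarray) -> np.ndarray:
    """Run `obs` through the model and return the output of model[layer_idx]."""
    captured = {}

    def hook(_module, _inputs, output):
        captured["acts"] = output.detach()

    handle = model[layer_idx].register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(torch.as_tensor(obs, dtype=torch.float32))
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["acts"].cpu().numpy()


def probe_r2(acts: np.ndarray, labels: np.ndarray, regressor, seed: int = 0):
    """Fit `regressor` on 80% of the data; return (held-out R², fitted regressor)."""
    x_tr, x_te, y_tr, y_te = train_test_split(
        acts, labels, test_size=0.2, random_state=seed
    )
    regressor.fit(x_tr, y_tr)
    return r2_score(y_te, regressor.predict(x_te)), regressor


# Toy policy: index 5 is the third Tanh, so layer_idx=5 probes a 64-d hidden layer.
torch.manual_seed(0)
rng = np.random.default_rng(0)
model = nn.Sequential(
    nn.Linear(28, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
    nn.Linear(64, 64), nn.Tanh(),
    nn.Linear(64, 8),
)
obs = rng.standard_normal((500, 28)).astype(np.float32)
y = obs[:, 10]  # synthetic target, as in the minimal example

acts = collect_activations(model, layer_idx=5, obs=obs)
ridge_r2, ridge = probe_r2(acts, y, Ridge(alpha=1.0))
lasso_r2, lasso = probe_r2(acts, y, Lasso(alpha=0.01, max_iter=10_000))
n_nonzero = int(np.count_nonzero(lasso.coef_))  # sparsity of the Lasso probe
print(f"Ridge R²={ridge_r2:.3f}  Lasso R²={lasso_r2:.3f}  nonzero={n_nonzero}")
```

The Ridge weight vector (`ridge.coef_`) is the kind of direction that the ablation step then removes from the hidden layer.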

### Ablation

```python
from rl_interrogate import AblationHook

hook = AblationHook(policy_net, layer_idx=5, direction=probe_direction)  # e.g. fitted probe weights
with hook.apply(alpha=0.0):   # alpha=0 zeros the direction
    rewards = run_episodes(policy_net, env, n=100)  # run_episodes: your own rollout helper
```
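One common way to implement direction ablation is a PyTorch forward hook that rescales the activation component along a unit vector `d`: `h' = h + (alpha - 1) * (h . d) * d`, so `alpha = 0` removes the component and `alpha = 1` leaves `h` unchanged. The sketch below assumes `AblationHook` follows that convention (the README only states that `alpha=0` zeros the direction); `ablate_direction` is a hypothetical helper, not part of the library.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def ablate_direction(model: nn.Sequential, layer_idx: int, direction, alpha: float = 0.0):
    """Temporarily scale the component of model[layer_idx]'s output along `direction` by alpha."""
    unit = torch.as_tensor(direction, dtype=torch.float32).flatten()
    unit = unit / unit.norm()  # normalise so the projection coefficient is h . d

    def hook(_module, _inputs, output):
        d = unit.to(device=output.device, dtype=output.dtype)
        coeff = (output @ d).unsqueeze(-1)  # (batch, 1) projection onto d
        # A non-None return value from a forward hook replaces the module output.
        return output + (alpha - 1.0) * coeff * d

    handle = model[layer_idx].register_forward_hook(hook)
    try:
        yield
    finally:
        handle.remove()  # restore the unmodified network on exit


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(28, 64), nn.Tanh(), nn.Linear(64, 8))
obs = torch.randn(32, 28)
direction = torch.randn(64)  # stand-in for a fitted probe direction

with torch.no_grad():
    baseline = model(obs)
    with ablate_direction(model, layer_idx=1, direction=direction, alpha=0.0):
        ablated = model(obs)
    with ablate_direction(model, layer_idx=1, direction=direction, alpha=1.0):
        unchanged = model(obs)
    restored = model(obs)  # hook has been removed
```

Using a context manager ties the hook's lifetime to the `with` block, so an exception during a rollout cannot leave the policy permanently ablated.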

### PCA Utilities

```python
from rl_interrogate import fit_pca, project_subspace

pca = fit_pca(activations, n_components=20)
acts_k = project_subspace(activations, pca, k=1)  # rank-1 projection
```
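Assuming `fit_pca` wraps scikit-learn's `PCA` and `project_subspace` keeps the top-`k` components (the README does not spell this out), a rank-`k` projection amounts to centering, projecting onto the first `k` rows of `components_`, and adding the mean back. `project_top_k` below is a hypothetical stand-in, shown on synthetic data.

```python
import numpy as np
from sklearn.decomposition import PCA


def project_top_k(activations: np.ndarray, pca: PCA, k: int) -> np.ndarray:
    """Project activations onto the top-k principal components, in the original space."""
    basis = pca.components_[:k]          # (k, d); rows are orthonormal
    centered = activations - pca.mean_
    return centered @ basis.T @ basis + pca.mean_


rng = np.random.default_rng(0)
# Correlated synthetic "activations": 1000 samples of a 64-d hidden layer.
activations = rng.standard_normal((1000, 64)) @ rng.standard_normal((64, 64))

pca = PCA(n_components=20).fit(activations)
acts_k = project_top_k(activations, pca, k=1)  # rank-1 projection
print("variance explained by PC1:", pca.explained_variance_ratio_[0])
```

Because the projection is done in activation space rather than PC-score space, the result has the same shape as the layer output and could be substituted back into the network.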

### Visualization

```python
from rl_interrogate import plot_probe_heatmap, plot_ablation_curve

plot_probe_heatmap(activations, labels, title="Layer 5 probe")
plot_ablation_curve(alphas=[0.0, 0.5, 1.0], means=[1.05, 0.97, 0.90])
```

## Running Tests

```bash
pip install -e ".[dev]"   # test dependencies: pytest, hypothesis
pytest rl_interrogate/tests/ -v
```

## Link to Paper

This library implements the interrogation protocol described in:

> *WakeRider: Emergent V-Formation Flight via Wake Exploitation*
> Section 3.3: The rl_interrogate Library

The protocol consists of four steps:
1. **Linear probing** — fit Ridge regression from hidden activations to a field label
2. **Polarity inversion** — negate the sensor; verify R² drops (causal, not correlational)
3. **Single-direction ablation** — zero the probe direction; measure performance change
4. **Subspace variance** — greedy PCA selection to find the minimal sufficient subspace
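Step 4 leaves the selection criterion open. One plausible reading, sketched here as an assumption rather than the paper's exact procedure: project activations onto the top principal components, then greedily add whichever component most improves a cross-validated Ridge probe until the probe recovers a target fraction (here 95%) of the R² obtained from all retained components. `greedy_pc_subspace`, `target_frac`, and the synthetic data are illustrative.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score


def cv_r2(x: np.ndarray, y: np.ndarray) -> float:
    """5-fold cross-validated R² of a Ridge probe."""
    return float(cross_val_score(Ridge(alpha=1.0), x, y, cv=5, scoring="r2").mean())


def greedy_pc_subspace(acts, y, n_components=20, target_frac=0.95):
    """Greedily pick PCs until the probe reaches target_frac of the all-PC R²."""
    scores = PCA(n_components=n_components).fit_transform(acts)
    full_r2 = cv_r2(scores, y)
    selected, remaining = [], list(range(n_components))
    best_r2 = -np.inf
    while remaining:
        # Try adding each remaining PC and keep the one that helps the probe most.
        best_r2, best_j = max((cv_r2(scores[:, selected + [j]], y), j) for j in remaining)
        selected.append(best_j)
        remaining.remove(best_j)
        if best_r2 >= target_frac * full_r2:
            break
    return selected, best_r2, full_r2


# Synthetic activations: a high-variance nuisance axis plus a lower-variance
# axis that carries the label, so the label lives in PC2 (index 1), not PC1.
rng = np.random.default_rng(0)
n, d = 600, 32
y = rng.standard_normal(n)
acts = 0.1 * rng.standard_normal((n, d))
acts[:, 0] += 3.0 * rng.standard_normal(n)
acts[:, 1] += y

selected, best_r2, full_r2 = greedy_pc_subspace(acts, y)
print(f"selected PCs {selected}: R²={best_r2:.3f} (all 20 PCs: {full_r2:.3f})")
```

The toy data illustrates why selection is greedy on probe R² rather than on explained variance: the top-variance component carries no label information, so taking PCs in variance order would pick the wrong one first.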
