Metadata-Version: 2.4
Name: disco-torch
Version: 0.1.0
Summary: PyTorch port of DeepMind's Disco103 meta-learned RL update rule
Author: asystemoffields
License-Expression: Apache-2.0
Project-URL: Homepage, https://github.com/asystemoffields/disco-torch
Project-URL: Repository, https://github.com/asystemoffields/disco-torch
Project-URL: Issues, https://github.com/asystemoffields/disco-torch/issues
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0
Requires-Dist: numpy>=1.24
Provides-Extra: hub
Requires-Dist: huggingface-hub>=0.20; extra == "hub"
Provides-Extra: examples
Requires-Dist: gymnasium>=1.0; extra == "examples"
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: gymnasium>=1.0; extra == "dev"
Requires-Dist: huggingface-hub>=0.20; extra == "dev"
Provides-Extra: jax
Requires-Dist: jax; extra == "jax"
Requires-Dist: dm-haiku; extra == "jax"
Requires-Dist: rlax; extra == "jax"
Requires-Dist: distrax; extra == "jax"
Dynamic: license-file

# disco-torch

A PyTorch port of DeepMind's **Disco103** — the meta-learned reinforcement learning update rule from [*Discovering State-of-the-art Reinforcement Learning Algorithms*](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).

## What is DiscoRL?

Instead of hand-crafted loss functions like those of PPO or GRPO, DiscoRL uses a small LSTM-based neural network (the "meta-network") that **generates loss targets** for RL agents. Given a rollout of agent experience — policy logits, rewards, advantages, auxiliary predictions — the meta-network outputs target distributions. The agent then minimizes the KL divergence between its outputs and these learned targets.
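
In code, the core idea looks roughly like the following (a hypothetical sketch, not the package API; the exact KL direction and weighting follow the paper):

```python
import torch
import torch.nn.functional as F

# Illustrative only: a KL(target || agent) policy loss that pushes the
# agent's distribution toward the meta-network's target. The function
# name and KL direction here are assumptions for illustration.
def policy_loss_vs_meta_target(agent_logits, target_logits):
    log_p = F.log_softmax(agent_logits, dim=-1)             # agent policy
    log_q = F.log_softmax(target_logits.detach(), dim=-1)   # fixed target
    return (log_q.exp() * (log_q - log_p)).sum(-1).mean()   # mean over T, B
```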

The Disco103 checkpoint (754,778 parameters) was meta-trained by DeepMind across 103 environments spanning Atari, ProcGen, and DMLab-30. It generalizes as a drop-in update rule for new tasks — no reward shaping, no hand-tuned, task-specific loss design.

## Why a PyTorch port?

The [original implementation](https://github.com/google-deepmind/disco_rl) uses JAX + Haiku. This port lets you use Disco103 in PyTorch training pipelines with no JAX dependency at inference time.

## Installation

```bash
pip install disco-torch
```

With optional extras:

```bash
pip install "disco-torch[hub]"       # HuggingFace Hub weight downloads
pip install "disco-torch[examples]"  # gymnasium for running examples
pip install "disco-torch[dev]"       # pytest + all extras for development
```

### Weights

Option 1 — Download from HuggingFace Hub (requires `pip install "disco-torch[hub]"`):

```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from HuggingFace Hub
```

Option 2 — Manual download from the [disco_rl repo](https://github.com/google-deepmind/disco_rl):

```bash
cp path/to/disco_103.npz weights/
```

```python
load_disco103_weights(rule, "weights/disco_103.npz")
```

## Quick start

```python
import torch
from disco_torch import DiscoUpdateRule, UpdateRuleInputs, load_disco103_weights

# Load the meta-network with pretrained weights
rule = DiscoUpdateRule()
load_disco103_weights(rule, "weights/disco_103.npz")

# Initialize meta-RNN state (persists across training steps)
state = rule.meta_net.initial_meta_rnn_state()

# Run the meta-network on a rollout; `inputs` is an UpdateRuleInputs
# built from your rollout (see construct_input in disco_torch.transforms)
with torch.no_grad():
    meta_out, new_state = rule.meta_net(inputs, state)
    # meta_out["pi"]  — policy loss targets  [T, B, A]
    # meta_out["y"]   — value loss targets   [T, B, 600]
    # meta_out["z"]   — auxiliary loss targets [T, B, 600]
```

### Full training loop

```python
# At each learner step (rollout, agent_params, unroll_fn and hyper_params
# come from your own training setup):
meta_out, new_meta_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)

# Compute agent loss (KL divergence against meta-network targets)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)

# Value function loss (no meta-gradient)
value_loss, value_logs = rule.agent_loss_no_meta(rollout, meta_out, hyper_params)

# Carry the meta-RNN state forward to the next learner step
meta_state = new_meta_state
```

## Architecture

```
Outer (per-trajectory):
  y_net           MLP [600 -> 16 -> 1]         Value prediction embedding
  z_net           MLP [600 -> 16 -> 1]         Auxiliary prediction embedding
  policy_net      Conv1dNet [9 -> 16 -> 2]     Action-conditional embedding
  trajectory_rnn  LSTM(27, 256)                Reverse-unrolled over trajectory
  state_gate      Linear(128 -> 256)           Multiplicative gate from meta-RNN
  y_head / z_head Linear(256 -> 600)           Loss targets for y and z
  pi_conv + head  Conv1dNet [258 -> 16] -> 1   Policy loss target (per action)

Meta-RNN (per-lifetime):
  Separate y/z/policy nets, input MLP(29 -> 16), LSTMCell(16, 128)
```

The outer network processes each trajectory with a reverse-unrolled LSTM. The meta-RNN operates at a slower timescale — it sees batch-time averages and modulates the outer network via a multiplicative gate. This two-level architecture lets the update rule adapt its behavior over an agent's lifetime.
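
To make the gating concrete, here is a hypothetical sketch of the pattern (module and argument names are illustrative, not the package's internal API):

```python
import torch
import torch.nn as nn

class GatedTrajectoryCore(nn.Module):
    """Illustrative sketch: reverse-unrolled trajectory LSTM whose outputs
    are modulated multiplicatively by the slow meta-RNN state."""

    def __init__(self, in_dim=27, traj_hidden=256, meta_hidden=128):
        super().__init__()
        self.trajectory_rnn = nn.LSTM(in_dim, traj_hidden)      # cf. LSTM(27, 256)
        self.state_gate = nn.Linear(meta_hidden, traj_hidden)   # cf. Linear(128 -> 256)

    def forward(self, traj_feats, meta_h):
        # traj_feats: [T, B, 27]; reverse-unroll by flipping the time axis
        out, _ = self.trajectory_rnn(traj_feats.flip(0))
        out = out.flip(0)                  # restore forward time order
        gate = self.state_gate(meta_h)     # meta_h: [B, 128] meta-RNN hidden
        return out * gate.unsqueeze(0)     # slow-timescale multiplicative gate
```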

## End-to-end example: Tiny Transformer

The flagship example trains a **1-layer causal transformer** (~141K params) to generate strictly increasing digit sequences, using Disco103 as the RL update rule. No supervised data — the agent learns purely from a per-token reward signal.
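
As a concrete illustration, a per-token reward for this task could look like the sketch below (hypothetical; the actual reward shaping in `examples/tinylm_disco.py` may differ):

```python
# Hypothetical reward: +1 whenever a token exceeds its predecessor,
# -1 otherwise; the first token has no predecessor and gets 0.
def per_token_reward(tokens):
    rewards = [0.0]
    for prev, cur in zip(tokens, tokens[1:]):
        rewards.append(1.0 if cur > prev else -1.0)
    return rewards

per_token_reward([0, 1, 5, 6, 5, 7, 8, 9])
# -> [0.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0]
```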

**[Open in Google Colab](examples/tinylm_disco_colab.ipynb)** (runs on a free T4 GPU)

With curriculum learning over sequence lengths (4→6→8), the agent reaches **86% strictly increasing** outputs on 8-token sequences, producing samples like `[0, 1, 5, 6, 5, 7, 8, 9]`. See [`experiments.md`](experiments.md) for the full research log.

```bash
# Standalone script (local GPU or CPU)
python examples/tinylm_disco.py --weights weights/disco_103.npz
```

A simpler CartPole example is also provided for API reference:

```bash
python examples/cartpole_disco.py --weights weights/disco_103.npz
```

> **Note:** Disco103 was meta-trained on 103 complex environments (Atari, ProcGen, DMLab-30). Both CartPole and token generation are outside this distribution. See [`experiments.md`](experiments.md) for analysis of how this affects transfer.

## Package structure

```
disco_torch/
  __init__.py          Public API
  types.py             Dataclasses: UpdateRuleInputs, MetaNetInputOption, ValueOuts, etc.
  transforms.py        Input transforms and construct_input()
  meta_net.py          DiscoMetaNet — the full LSTM meta-network
  update_rule.py       DiscoUpdateRule — meta-net + value computation + loss
  value_utils.py       V-trace, TD-error, advantage estimation, Q-values
  utils.py             batch_lookup, signed_logp1, 2-hot encoding, EMA
  load_weights.py      Maps JAX/Haiku NPZ keys -> PyTorch modules

examples/
  tinylm_disco.py          Train a tiny transformer with Disco103 (standalone)
  tinylm_disco_colab.ipynb Google Colab notebook (recommended)
  cartpole_disco.py        CartPole API reference example

scripts/
  inspect_disco103.py      Print NPZ weight names and shapes
  validate_against_jax.py  Numerical comparison: PyTorch vs JAX reference

tests/
  test_utils.py            Unit tests for utility functions
  test_building_blocks.py  Unit tests for network building blocks
  test_meta_net.py         Snapshot tests for meta-network forward pass
```

## Numerical validation

All outputs match the JAX reference implementation within float32 precision:

| Output | Max diff | Status |
|--------|----------|--------|
| pi (policy targets) | < 1.3e-06 | PASS |
| y (value targets) | < 1.3e-06 | PASS |
| z (auxiliary targets) | < 1.3e-06 | PASS |
| meta_input_emb | < 1.3e-06 | PASS |
| meta_rnn_h | < 1.3e-06 | PASS |

To run the test suite (no JAX required):

```bash
pip install "disco-torch[dev]"
pytest
```

To run JAX cross-validation (requires JAX + disco_rl):

```bash
pip install disco_rl jax dm-haiku rlax distrax
python scripts/validate_against_jax.py
```

## Key implementation details

- **HaikuLSTMCell**: Haiku computes LSTM gates in the order `[i, g, f, o]` and adds +1 to the forget-gate bias, whereas PyTorch uses `[i, f, g, o]`. A custom LSTM cell handles this; see the sketch after this list.
- **Weight mapping**: The 42 JAX/Haiku parameters have nested path names (e.g., `lstm/~/meta_lstm/~unroll/mlp_2/~/linear_0/w`). `load_weights.py` maps every one to the correct PyTorch module.
- **Conv1dBlock**: Each block concatenates per-action features with their mean across actions before the convolution — matching the JAX implementation's broadcast pattern (also sketched below).
- **Value utilities**: V-trace, Retrace-style Q-value estimation, signed hyperbolic transforms, and 2-hot categorical encoding are all ported.
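
For the first bullet, a minimal sketch of a Haiku-ordered LSTM cell (illustrative, assuming the standard LSTM equations; not the package's `HaikuLSTMCell` verbatim):

```python
import torch
import torch.nn as nn

class HaikuOrderLSTMCell(nn.Module):
    """LSTM cell following Haiku's gate order [i, g, f, o] and +1 forget bias."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h, c = state
        gates = self.linear(torch.cat([x, h], dim=-1))
        i, g, f, o = gates.chunk(4, dim=-1)   # Haiku order, not PyTorch's [i, f, g, o]
        f = torch.sigmoid(f + 1.0)            # Haiku's +1 forget-gate bias
        c = f * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, (h, c)
```

And for the **Conv1dBlock** bullet, the mean-concat broadcast pattern amounts to (again a hypothetical sketch):

```python
def action_mean_concat(feats):
    # feats: [B, C, A] per-action features over A actions; concatenate each
    # action's features with the mean across actions on the channel axis.
    mean = feats.mean(dim=-1, keepdim=True).expand_as(feats)
    return torch.cat([feats, mean], dim=1)  # [B, 2C, A], fed to the Conv1d
```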

## Requirements

- Python >= 3.11
- PyTorch >= 2.0
- NumPy >= 1.24

## License

Apache 2.0 — same as the original disco_rl.

## Citation

If you use this port, please cite the original paper:

```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```

## Acknowledgments

This is a community port of [google-deepmind/disco_rl](https://github.com/google-deepmind/disco_rl). All credit for the algorithm, architecture, and pretrained weights goes to the original authors.
