QR-DQN Implementation Plan
For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
Goal: Add qr_dqn (Quantile Regression DQN) as a first-class discrete-action algorithm, with configs, examples, registry wiring, public API exports, and tests.
Architecture: Reuse the existing train_dqn trainer entrypoint and checkpoint flow to minimize churn. Implement a distributional Q-network (MLPQRQNetwork) that outputs quantiles with shape (batch, action_dim, num_quantiles) and selects actions via the mean over quantiles. Implement QRDQN as a dedicated algorithm with a target network and quantile Huber loss. Wire qr_dqn through the registry load/evaluate/predict paths so CLI and managed API work unchanged.
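The output shape and the mean-over-quantiles action rule described above can be sketched as follows. This is a minimal illustration, not the planned implementation: the constructor arguments (obs_dim, hidden_dim) and the single hidden layer are assumptions, and only the class name, the (batch, action_dim, num_quantiles) shape, and the act() rule come from this plan.

```python
import torch
from torch import nn


class MLPQRQNetwork(nn.Module):
    """Sketch of a quantile Q-network; layer sizes are illustrative assumptions."""

    def __init__(self, obs_dim: int, action_dim: int, num_quantiles: int, hidden_dim: int = 64):
        super().__init__()
        self.action_dim = action_dim
        self.num_quantiles = num_quantiles
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim * num_quantiles),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # (batch, obs_dim) -> (batch, action_dim, num_quantiles)
        return self.net(obs).view(-1, self.action_dim, self.num_quantiles)

    @torch.no_grad()
    def act(self, obs: torch.Tensor) -> torch.Tensor:
        # The expected return of each action is the mean over the quantile axis;
        # the greedy action maximizes that mean.
        q_values = self.forward(obs).mean(dim=-1)
        return q_values.argmax(dim=-1)


net = MLPQRQNetwork(obs_dim=4, action_dim=2, num_quantiles=51)
obs = torch.zeros(8, 4)
quantiles = net(obs)    # one row of 51 quantiles per action
actions = net.act(obs)  # one greedy action index per batch element
```

Averaging over the last axis reduces the quantile output to ordinary Q-values, which is why existing greedy and epsilon-greedy action selection can stay the same.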
Tech Stack: Python 3.10, PyTorch, Gymnasium, pytest
Task 1: Add failing QR-DQN coverage
Files:
- Create: tests/test_qr_dqn_update.py
- Modify: tests/test_dqn_trainer_smoke.py
- Modify: tests/test_cli.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py
- Create: tests/test_qr_dqn_reference_script.py
Step 1: Write the failing test
Add tests that expect:
- MLPQRQNetwork outputs quantiles with shape (batch, action_dim, num_quantiles), and act() selects actions via the mean over quantiles
- QRDQN.update() returns named metrics (loss, q_value_mean, target_mean, td_error_mean)
- algo: qr_dqn can be trained via the registry and writes a checkpoint
- CLI train --config works with algo: qr_dqn
- the managed API exports QRDQN
- the reference script runs successfully
Step 2: Run test to verify it fails
Run: pytest tests/test_qr_dqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_qr_dqn_reference_script.py -q
Expected: FAIL because the model, algorithm, and wiring do not exist yet.
Task 2: Implement the model and algorithm
Files:
- Create: src/rl_training/models/mlp_qr_q_network.py
- Create: src/rl_training/algorithms/qr_dqn.py
- Modify: src/rl_training/models/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
Step 1: Write minimal implementation
Implement:
- MLPQRQNetwork, with .forward() returning quantiles of shape (batch, action_dim, num_quantiles) and .act() selecting the action with the highest mean over quantiles
- QRDQN, with target network sync, quantile Huber loss, and .state_dict()/.load_state_dict()
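For orientation, the quantile Huber loss named above might look like the sketch below. It follows the standard QR-DQN formulation (quantile midpoints, pairwise TD errors, asymmetric Huber weighting). The function name, the kappa default, and the assumption that the caller has already gathered the taken-action quantiles and built a detached Bellman target from the target network are illustrative choices, not details fixed by this plan.

```python
import torch


def quantile_huber_loss(pred: torch.Tensor, target: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
    """pred: (batch, N) quantiles of the taken action; target: (batch, N') target quantiles."""
    num_quantiles = pred.shape[1]
    # Quantile midpoints tau_i = (2i + 1) / (2N), shape (N,).
    tau = (torch.arange(num_quantiles, dtype=pred.dtype, device=pred.device) + 0.5) / num_quantiles
    # Pairwise TD errors u[b, i, j] = target[b, j] - pred[b, i].
    u = target.unsqueeze(1) - pred.unsqueeze(2)
    abs_u = u.abs()
    # Huber loss: quadratic inside [-kappa, kappa], linear outside.
    huber = torch.where(abs_u <= kappa, 0.5 * u.pow(2), kappa * (abs_u - 0.5 * kappa))
    # Asymmetric weight |tau_i - 1{u < 0}| turns Huber into a quantile loss.
    weight = (tau.view(1, -1, 1) - (u.detach() < 0).to(pred.dtype)).abs()
    # Mean over target quantiles j, sum over predicted quantiles i, mean over batch.
    return (weight * huber / kappa).mean(dim=2).sum(dim=1).mean()


# Toy check: two predicted quantiles, both target quantiles equal to 1.0.
pred = torch.tensor([[0.0, 2.0]], requires_grad=True)
target = torch.tensor([[1.0, 1.0]])
loss = quantile_huber_loss(pred, target)
loss.backward()
```

In the toy check every pairwise error has magnitude 1, so each Huber term is 0.5 and each weight is 0.25, giving a loss of 0.25. The gradient pushes the lower quantile up and the upper quantile down toward the target.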
Step 2: Run focused tests
Run: pytest tests/test_qr_dqn_update.py -q
Expected: PASS
Task 3: Wire QR-DQN into the DQN trainer
Files:
- Modify: src/rl_training/runtime/dqn_trainer.py
Step 1: Update trainer factories
Add:
- a _build_q_network() branch for config.algo == "qr_dqn" that returns MLPQRQNetwork
- a _build_algorithm() branch that returns QRDQN
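The factory branch could take roughly the following shape. Everything except the names _build_q_network, MLPQRQNetwork, config.algo, and algo_kwargs.num_quantiles is a hypothetical stand-in: the real trainer's config type, existing network class, and constructor arguments are not shown in this plan.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class TrainConfig:  # hypothetical stand-in for the trainer's config object
    algo: str
    algo_kwargs: Dict[str, Any] = field(default_factory=dict)


class ScalarQNetwork:  # stand-in for whatever network the trainer already builds
    pass


class MLPQRQNetwork:  # stand-in; the real class is the model from Task 2
    def __init__(self, num_quantiles: int) -> None:
        self.num_quantiles = num_quantiles


def _build_q_network(config: TrainConfig):
    # New branch: QR-DQN needs the distributional network.
    if config.algo == "qr_dqn":
        return MLPQRQNetwork(num_quantiles=config.algo_kwargs["num_quantiles"])
    # Existing behavior for every other algo stays untouched.
    return ScalarQNetwork()


qr_net = _build_q_network(TrainConfig(algo="qr_dqn", algo_kwargs={"num_quantiles": 51}))
dqn_net = _build_q_network(TrainConfig(algo="dqn"))
```

A _build_algorithm() branch would presumably mirror this, returning QRDQN for the same condition.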
Step 2: Run focused tests
Run: pytest tests/test_dqn_trainer_smoke.py::test_train_qr_dqn_writes_checkpoint_and_metrics -q
Expected: PASS
Task 4: Wire registry + managed API
Files:
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
- Modify: src/rl_training/__init__.py
Step 1: Registry load/evaluate/predict
Add a dedicated loader, evaluator, and predictor (similar to C51) so checkpoints can be evaluated and used for prediction:
- _load_qr_dqn_algorithm()
- _evaluate_qr_dqn()
- _predict_qr_dqn()
Add _ALGORITHM_REGISTRY["qr_dqn"].
Step 2: Managed API exports
Add class QRDQN(ManagedAlgorithm) with algo_name = "qr_dqn", and export via rl_training.api and rl_training top-level.
Step 3: Run focused tests
Run: pytest tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py -q
Expected: PASS
Task 5: Add config + example + README mention
Files:
- Create: configs/qr_dqn/cartpole.yaml
- Create: examples/qr_dqn_cartpole_reference.py
- Modify: README.md
Step 1: Write minimal implementation
Add a CartPole config including algo_kwargs.num_quantiles (e.g. 51) and an example script consistent with other DQN examples.
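A minimal sketch of the QR-DQN-specific part of that config is shown below. Only the algo and algo_kwargs.num_quantiles keys are taken from this plan. The remaining keys are assumed to mirror the existing DQN CartPole config and are deliberately left out rather than guessed.

```yaml
# configs/qr_dqn/cartpole.yaml (sketch; only these two settings come from the plan)
algo: qr_dqn
algo_kwargs:
  num_quantiles: 51
# Environment, optimizer, replay buffer, and schedule keys are omitted here;
# copy them from the existing DQN CartPole config.
```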
Step 2: Run focused tests
Run: pytest tests/test_qr_dqn_reference_script.py -q
Expected: PASS
Task 6: Verify end-to-end
Files:
- Verify only
Run:
- Targeted: pytest tests/test_qr_dqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_qr_dqn_reference_script.py -q
- Full: pytest -q
Expected: PASS