
QR-DQN Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

Goal: Add qr_dqn (Quantile Regression DQN) as a first-class discrete-action algorithm, with configs, examples, registry wiring, public API exports, and tests.

Architecture: Reuse the existing train_dqn trainer entrypoint and checkpoint flow to minimize churn. Implement a distributional Q-network (MLPQRQNetwork) that outputs quantiles with shape (batch, action_dim, num_quantiles) and selects actions via the mean over quantiles. Implement QRDQN as a dedicated algorithm with a target network and quantile Huber loss. Wire qr_dqn through the registry load/evaluate/predict paths so CLI and managed API work unchanged.
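The shape contract above can be sketched as a small PyTorch module. This is a minimal sketch, not the planned implementation: it assumes a flat observation vector, two hidden layers of width 64 with ReLU, and a default of 51 quantiles. The hidden sizes and depth are illustrative assumptions.

```python
import torch
from torch import nn


class MLPQRQNetwork(nn.Module):
    """Sketch: an MLP that emits num_quantiles return quantiles per action.

    Hidden width and depth are illustrative assumptions.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        num_quantiles: int = 51,
        hidden_dim: int = 64,
    ) -> None:
        super().__init__()
        self.action_dim = action_dim
        self.num_quantiles = num_quantiles
        self.body = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            # One output per (action, quantile) pair, reshaped in forward().
            nn.Linear(hidden_dim, action_dim * num_quantiles),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # (batch, obs_dim) -> (batch, action_dim, num_quantiles)
        return self.body(obs).reshape(-1, self.action_dim, self.num_quantiles)

    @torch.no_grad()
    def act(self, obs: torch.Tensor) -> torch.Tensor:
        # Q(s, a) is the mean of the quantiles for action a; act greedily on it.
        q_values = self.forward(obs).mean(dim=-1)  # (batch, action_dim)
        return q_values.argmax(dim=-1)  # (batch,)


if __name__ == "__main__":
    net = MLPQRQNetwork(obs_dim=4, action_dim=2, num_quantiles=51)
    obs = torch.randn(8, 4)
    print(net(obs).shape)  # torch.Size([8, 2, 51])
    print(net.act(obs).shape)  # torch.Size([8])
```

Because act() collapses the quantiles to a mean before the argmax, it has the same signature and output as a plain DQN network's greedy action, which is what lets the existing train_dqn loop drive it unchanged.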

Tech Stack: Python 3.10, PyTorch, Gymnasium, pytest


Task 1: Add failing QR-DQN coverage

Files:
- Create: tests/test_qr_dqn_update.py
- Modify: tests/test_dqn_trainer_smoke.py
- Modify: tests/test_cli.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py
- Create: tests/test_qr_dqn_reference_script.py

Step 1: Write the failing test

Add tests that expect:
- MLPQRQNetwork returns quantiles of shape (batch, action_dim, num_quantiles), and act() picks the action with the highest mean over quantiles
- QRDQN.update() returns named metrics (loss, q_value_mean, target_mean, td_error_mean)
- a config with algo: qr_dqn can be trained through the registry and writes a checkpoint
- CLI train --config works with algo: qr_dqn
- the managed API exports QRDQN
- the reference script runs successfully
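The first two expectations can be captured as reusable assertion helpers that tests/test_qr_dqn_update.py would call with the real MLPQRQNetwork and the dict returned by QRDQN.update(). A sketch, assuming forward() takes a batched float tensor and update() returns a plain dict of numbers; the helper names are hypothetical.

```python
import math

import torch
from torch import nn

# Metric names listed in the plan for QRDQN.update().
REQUIRED_UPDATE_METRICS = frozenset({"loss", "q_value_mean", "target_mean", "td_error_mean"})


def assert_quantile_network_contract(
    net: nn.Module, obs: torch.Tensor, action_dim: int, num_quantiles: int
) -> None:
    """Check forward() shape and that act() is greedy on the quantile mean."""
    quantiles = net(obs)
    assert tuple(quantiles.shape) == (obs.shape[0], action_dim, num_quantiles)
    expected = quantiles.mean(dim=-1).argmax(dim=-1)
    assert torch.equal(net.act(obs), expected)


def assert_update_metrics_contract(metrics: dict) -> None:
    """Check that update() reports every named metric as a finite number."""
    # Iterating a dict yields its keys, so difference() finds missing names.
    missing = REQUIRED_UPDATE_METRICS.difference(metrics)
    assert not missing, f"missing metrics: {sorted(missing)}"
    for name in REQUIRED_UPDATE_METRICS:
        assert math.isfinite(float(metrics[name])), name
```

Keeping the contract in helpers lets the smoke and update tests share one definition of "correct shape" and "named metrics" instead of repeating the checks.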

Step 2: Run test to verify it fails

Run: pytest tests/test_qr_dqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_qr_dqn_reference_script.py -q

Expected: FAIL because the model/algorithm/wiring does not exist yet.


Task 2: Implement the model and algorithm

Files:
- Create: src/rl_training/models/mlp_qr_q_network.py
- Create: src/rl_training/algorithms/qr_dqn.py
- Modify: src/rl_training/models/__init__.py
- Modify: src/rl_training/algorithms/__init__.py

Step 1: Write minimal implementation

Implement:
- MLPQRQNetwork, whose .forward() returns quantiles and whose .act() selects the action with the highest mean over quantiles
- QRDQN, with target-network sync, the quantile Huber loss, and .state_dict()/.load_state_dict()
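A minimal sketch of the algorithm, assuming plain tensor batches (obs, actions, rewards, next_obs, dones), Adam, a hard target copy every target_update_interval updates, and the kappa-normalized quantile Huber loss from the original QR-DQN paper. The constructor arguments, the td_error_mean definition, and the state-dict keys are assumptions; the existing C51 algorithm may dictate different conventions.

```python
import copy

import torch
from torch import nn


def quantile_huber_loss(pred: torch.Tensor, target: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
    """pred, target: (batch, N) quantiles; target must already be detached."""
    n = pred.shape[-1]
    # Quantile midpoints tau_hat_i = (2i + 1) / (2N).
    tau_hat = (torch.arange(n, dtype=pred.dtype, device=pred.device) + 0.5) / n
    # Pairwise TD errors u[b, i, j] = target[b, j] - pred[b, i].
    u = target.unsqueeze(1) - pred.unsqueeze(2)
    abs_u = u.abs()
    huber = torch.where(abs_u <= kappa, 0.5 * u.pow(2), kappa * (abs_u - 0.5 * kappa))
    # Asymmetric quantile weight |tau_hat_i - 1{u < 0}|.
    weight = (tau_hat.view(1, -1, 1) - (u.detach() < 0).to(pred.dtype)).abs()
    # Mean over target samples j, sum over predicted quantiles i, mean over batch.
    return (weight * huber / kappa).mean(dim=2).sum(dim=1).mean()


class QRDQN:
    """Sketch of a QR-DQN learner around a (batch, A, N) quantile network."""

    def __init__(self, q_network: nn.Module, lr: float = 1e-3, gamma: float = 0.99,
                 kappa: float = 1.0, target_update_interval: int = 100) -> None:
        self.q_network = q_network
        self.target_network = copy.deepcopy(q_network).requires_grad_(False)
        self.optimizer = torch.optim.Adam(q_network.parameters(), lr=lr)
        self.gamma, self.kappa = gamma, kappa
        self.target_update_interval = target_update_interval
        self.num_updates = 0

    @staticmethod
    def _select(quantiles: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # (batch, A, N) indexed by (batch,) actions -> (batch, N)
        idx = actions.long().view(-1, 1, 1).expand(-1, 1, quantiles.shape[-1])
        return quantiles.gather(1, idx).squeeze(1)

    def update(self, obs, actions, rewards, next_obs, dones) -> dict[str, float]:
        pred = self._select(self.q_network(obs), actions)
        with torch.no_grad():
            next_quantiles = self.target_network(next_obs)
            # Greedy next action by the target network's mean over quantiles.
            next_actions = next_quantiles.mean(dim=-1).argmax(dim=-1)
            not_done = 1.0 - dones.float().unsqueeze(1)
            target = rewards.float().unsqueeze(1) + self.gamma * not_done * self._select(
                next_quantiles, next_actions
            )
        loss = quantile_huber_loss(pred, target, self.kappa)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.num_updates += 1
        if self.num_updates % self.target_update_interval == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
        pred = pred.detach()
        return {
            "loss": loss.item(),
            "q_value_mean": pred.mean().item(),
            "target_mean": target.mean().item(),
            # Assumed definition: |mean target - mean prediction|, averaged over the batch.
            "td_error_mean": (target.mean(dim=1) - pred.mean(dim=1)).abs().mean().item(),
        }

    def state_dict(self) -> dict:
        return {"q_network": self.q_network.state_dict(),
                "target_network": self.target_network.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "num_updates": self.num_updates}

    def load_state_dict(self, state: dict) -> None:
        self.q_network.load_state_dict(state["q_network"])
        self.target_network.load_state_dict(state["target_network"])
        self.optimizer.load_state_dict(state["optimizer"])
        self.num_updates = state["num_updates"]
```

The loss compares every predicted quantile against every target quantile, so its cost per sample is O(N²) in num_quantiles; for N around 51 this is small next to the network forward pass.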

Step 2: Run focused tests

Run: pytest tests/test_qr_dqn_update.py -q

Expected: PASS


Task 3: Wire QR-DQN into the DQN trainer

Files: - Modify: src/rl_training/runtime/dqn_trainer.py

Step 1: Update trainer factories

Add:
- a _build_q_network() branch that returns MLPQRQNetwork when config.algo == "qr_dqn"
- a _build_algorithm() branch that returns QRDQN for the same case

Step 2: Run focused tests

Run: pytest tests/test_dqn_trainer_smoke.py::test_train_qr_dqn_writes_checkpoint_and_metrics -q

Expected: PASS


Task 4: Wire registry + managed API

Files:
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
- Modify: src/rl_training/__init__.py

Step 1: Registry load/evaluate/predict

Add dedicated functions, following the C51 pattern, so that qr_dqn checkpoints can be evaluated and used for prediction:
- _load_qr_dqn_algorithm()
- _evaluate_qr_dqn()
- _predict_qr_dqn()

Add _ALGORITHM_REGISTRY["qr_dqn"].

Step 2: Managed API exports

Add class QRDQN(ManagedAlgorithm) with algo_name = "qr_dqn", and export it from both rl_training.api and the top-level rl_training package.

Step 3: Run focused tests

Run: pytest tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py -q

Expected: PASS


Task 5: Add config + example + README mention

Files:
- Create: configs/qr_dqn/cartpole.yaml
- Create: examples/qr_dqn_cartpole_reference.py
- Modify: README.md

Step 1: Write minimal implementation

Add a CartPole config including algo_kwargs.num_quantiles (e.g. 51) and an example script consistent with other DQN examples.
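num_quantiles is N, the number of fixed quantile fractions each action's return distribution is represented by. In QR-DQN these are the midpoints tau_hat_i = (2i + 1) / (2N), so 51 quantiles place tau_hat at 1/102, 3/102, ..., 101/102. A small sketch, assuming the YAML loads into a plain dict with an algo_kwargs mapping; read_num_quantiles is a hypothetical helper, and its default of 51 simply reuses the example value above.

```python
from typing import Any


def quantile_midpoints(num_quantiles: int) -> list[float]:
    """Fixed fractions tau_hat_i = (2i + 1) / (2N) that QR-DQN regresses onto."""
    if num_quantiles < 1:
        raise ValueError("num_quantiles must be >= 1")
    return [(2 * i + 1) / (2 * num_quantiles) for i in range(num_quantiles)]


def read_num_quantiles(config: dict[str, Any], default: int = 51) -> int:
    # Hypothetical accessor for algo_kwargs.num_quantiles in a parsed config.
    value = config.get("algo_kwargs", {}).get("num_quantiles", default)
    # bool is a subclass of int, so reject it explicitly.
    if not isinstance(value, int) or isinstance(value, bool) or value < 1:
        raise ValueError(f"algo_kwargs.num_quantiles must be a positive int, got {value!r}")
    return value
```

A larger N resolves the return distribution more finely, at the cost of a wider output layer (action_dim * N units) and a pairwise loss that grows as N².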

Step 2: Run focused tests

Run: pytest tests/test_qr_dqn_reference_script.py -q

Expected: PASS


Task 6: Verify end-to-end

Files: none (verification only)

Run:
- Targeted: pytest tests/test_qr_dqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_qr_dqn_reference_script.py -q
- Full: pytest -q

Expected: PASS