
IQN Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

Goal: Add iqn (Implicit Quantile Network) as a first-class discrete-action algorithm, with trainer wiring, registry support, public API exports, config/example coverage, and tests.

Architecture: Reuse the existing train_dqn path so iqn plugs into the same replay-buffer, checkpoint, evaluation, and CLI workflow as the rest of the DQN family. Implement an MLPIQNetwork that conditions value estimates on sampled quantile fractions and an IQN algorithm that applies quantile Huber loss against target-network samples. Keep scope narrow: discrete actions only, MLP observations only, no prioritized/offline hybrid behavior in this change.

Tech Stack: Python 3.10, PyTorch, Gymnasium, pytest


Task 1: Add failing IQN coverage

Files:
- Create: tests/test_iqn_update.py
- Create: tests/test_iqn_reference_script.py
- Modify: tests/test_dqn_trainer_smoke.py
- Modify: tests/test_cli.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py

Step 1: Write the failing test

Add tests that expect:
- MLPIQNetwork returns quantile samples with shape (batch, action_dim, num_quantiles)
- MLPIQNetwork.act() selects via the mean over sampled quantiles
- IQN.update() returns standard DQN-family metrics
- algo: iqn trains through the shared DQN trainer and writes a checkpoint
- CLI train --config supports algo: iqn
- managed API exports IQN
- the reference script runs successfully
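
The action-selection expectation above can be pinned down with a framework-free reference that a test could compare MLPIQNetwork.act() against. This is a minimal sketch: the helper name greedy_actions and the nested-list layout are assumptions for illustration, not part of the planned API, and the real network would return a tensor of the same (batch, action_dim, num_quantiles) shape.

```python
from typing import List


def greedy_actions(quantiles: List[List[List[float]]]) -> List[int]:
    """Pick, per batch row, the action with the highest mean quantile value.

    quantiles[b][a][k] is sampled return quantile k for action a in row b,
    i.e. the (batch, action_dim, num_quantiles) layout the tests expect.
    """
    actions = []
    for per_action in quantiles:
        # Averaging over sampled quantiles approximates the expected return Q(s, a).
        means = [sum(samples) / len(samples) for samples in per_action]
        # max() returns the first index on ties.
        actions.append(max(range(len(means)), key=means.__getitem__))
    return actions


batch = [
    [[0.0, 1.0, 2.0], [3.0, 3.0, 3.0]],  # means 1.0 and 3.0
    [[5.0, 5.0, 5.0], [0.0, 0.0, 9.0]],  # means 5.0 and 3.0
]
selected = greedy_actions(batch)
```

In a real test, the same batch would be fed through the network's quantile output and the chosen actions compared with this reference.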

Step 2: Run test to verify it fails

Run: PYTHONPATH=src pytest tests/test_iqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_iqn_reference_script.py -q
Expected: FAIL because the model, algorithm, and wiring do not exist yet.


Task 2: Implement the IQN model and algorithm

Files:
- Create: src/rl_training/models/mlp_iqn_network.py
- Create: src/rl_training/algorithms/iqn.py
- Modify: src/rl_training/models/__init__.py
- Modify: src/rl_training/algorithms/__init__.py

Step 1: Write minimal implementation

Implement:
- MLPIQNetwork with cosine quantile embeddings, sampled quantile fractions, and .act() based on the mean over sampled quantiles
- IQN with target network, sampled target quantiles, quantile Huber loss, state serialization, and train/eval mode helpers
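
As a rough illustration of the cosine quantile embedding, the sketch below samples fractions uniformly from [0, 1) and computes the raw cosine features for one fraction. The helper names sample_taus and cosine_features are hypothetical. In the IQN paper these features are passed through a learned linear layer with ReLU and then multiplied elementwise with the state embedding; that learned part is omitted here. The paper indexes the features from i = 0; some implementations start at 1 instead, so treat the indexing as an assumption to confirm against the chosen reference.

```python
import math
import random
from typing import List


def sample_taus(num_quantiles: int, rng: random.Random) -> List[float]:
    """Draw quantile fractions tau ~ U[0, 1), one per requested quantile."""
    return [rng.random() for _ in range(num_quantiles)]


def cosine_features(tau: float, embedding_dim: int) -> List[float]:
    """Return cos(pi * i * tau) for i = 0 .. embedding_dim - 1."""
    return [math.cos(math.pi * i * tau) for i in range(embedding_dim)]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
taus = sample_taus(8, rng)
features = [cosine_features(tau, embedding_dim=4) for tau in taus]
```

The result has one feature vector of length embedding_dim per sampled fraction, which is the input a learned embedding layer would consume.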

Step 2: Run focused tests

Run: PYTHONPATH=src pytest tests/test_iqn_update.py -q
Expected: PASS
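
For checking the loss inside IQN.update(), a scalar reference for the quantile Huber loss can be useful. The sketch below is framework-free and handles a single transition. The function names are hypothetical, the target samples are assumed to be already computed from the target network (reward plus discounted next-state quantiles) and treated as constants, and the reduction (sum over predicted quantiles of the mean over target samples) follows the IQN paper rather than anything this plan specifies.

```python
from typing import Sequence


def huber(u: float, kappa: float = 1.0) -> float:
    """Huber loss: quadratic for |u| <= kappa, linear beyond it."""
    abs_u = abs(u)
    if abs_u <= kappa:
        return 0.5 * u * u
    return kappa * (abs_u - 0.5 * kappa)


def quantile_huber(u: float, tau: float, kappa: float = 1.0) -> float:
    """Asymmetric Huber loss for TD error u = target - prediction at fraction tau."""
    indicator = 1.0 if u < 0.0 else 0.0
    return abs(tau - indicator) * huber(u, kappa) / kappa


def iqn_loss(
    pred: Sequence[float],
    taus: Sequence[float],
    target: Sequence[float],
    kappa: float = 1.0,
) -> float:
    """Sum over predicted quantiles of the mean loss over target samples.

    pred[i] is the online network's quantile value at fraction taus[i];
    target[j] is one target-network sample for the same transition.
    """
    total = 0.0
    for q, tau in zip(pred, taus):
        per_target = [quantile_huber(t - q, tau, kappa) for t in target]
        total += sum(per_target) / len(per_target)
    return total


loss = iqn_loss(pred=[0.0], taus=[0.5], target=[2.0, -2.0])
```

A batched PyTorch implementation would compute the same pairwise TD errors with broadcasting and then average over the batch; a test can compare one row of that result against this reference.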


Task 3: Wire IQN into the shared DQN trainer

Files:
- Modify: src/rl_training/runtime/dqn_trainer.py

Step 1: Update trainer factories

Add:
- _build_q_network() branch for config.algo == "iqn" returning MLPIQNetwork
- _build_algorithm() branch returning IQN
- trainer/evaluation typing updates so iqn follows the same training and greedy evaluation path as qr_dqn

Step 2: Run focused tests

Run: PYTHONPATH=src pytest tests/test_dqn_trainer_smoke.py::test_train_iqn_writes_checkpoint_and_metrics -q
Expected: PASS


Task 4: Wire registry + managed API + exports

Files:
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/__init__.py
- Modify: src/rl_training/algorithms/__init__.py

Step 1: Registry load/evaluate/predict

Add:
- _load_iqn_algorithm()
- _evaluate_iqn()
- _predict_iqn()
- _ALGORITHM_REGISTRY["iqn"]

Step 2: Managed API exports

Add class IQN(ManagedAlgorithm) with algo_name = "iqn" and expose it from public packages.

Step 3: Run focused tests

Run: PYTHONPATH=src pytest tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py -q
Expected: PASS


Task 5: Add config + example

Files:
- Create: configs/iqn/cartpole.yaml
- Create: examples/iqn_cartpole_reference.py

Step 1: Write minimal implementation

Add a CartPole config with IQN-specific hyperparameters such as num_quantiles, num_tau_samples, and embedding_dim. Add a reference script consistent with the existing example layout.
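
To make the three IQN-specific hyperparameters concrete, the sketch below models them as a small validated dataclass. The class name IQNHyperparams, the interpretation in the comments, and the example values are assumptions for illustration; the actual YAML keys and defaults should follow whatever the existing DQN-family configs use.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IQNHyperparams:
    """Hypothetical view of the IQN-specific keys in configs/iqn/cartpole.yaml."""
    num_quantiles: int    # assumed: fractions sampled for the online network per update
    num_tau_samples: int  # assumed: fractions sampled when acting or forming targets
    embedding_dim: int    # number of cosine features per fraction

    def __post_init__(self) -> None:
        # Every field counts samples or features, so each must be at least 1.
        for name in ("num_quantiles", "num_tau_samples", "embedding_dim"):
            value = getattr(self, name)
            if not isinstance(value, int) or value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")


# Illustrative values only; tune against the CartPole smoke test.
params = IQNHyperparams(num_quantiles=8, num_tau_samples=8, embedding_dim=64)
```

Validating at construction time surfaces a mistyped or zero value when the config is loaded instead of partway through training.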

Step 2: Run focused tests

Run: PYTHONPATH=src pytest tests/test_iqn_reference_script.py -q
Expected: PASS


Task 6: Verify end-to-end

Files:
- Verify only

Step 1: Run targeted verification

Run: PYTHONPATH=src pytest tests/test_iqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_iqn_reference_script.py -q
Expected: PASS

Step 2: Run full verification

Run: PYTHONPATH=src pytest -q
Expected: PASS