IQN Implementation Plan
For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
Goal: Add iqn (Implicit Quantile Network) as a first-class discrete-action algorithm, with trainer wiring, registry support, public API exports, config/example coverage, and tests.
Architecture: Reuse the existing train_dqn path so iqn plugs into the same replay-buffer, checkpoint, evaluation, and CLI workflow as the rest of the DQN family. Implement an MLPIQNetwork that conditions per-action return quantiles on sampled quantile fractions τ, and an IQN algorithm that minimizes the quantile Huber loss between online quantile estimates and target quantiles computed by the target network at independently sampled fractions. Keep the scope narrow: discrete actions only, MLP observation encoders only, and no prioritized or offline hybrid behavior in this change.
Tech Stack: Python 3.10, PyTorch, Gymnasium, pytest
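As a rough sketch of the update described in the Architecture line, the functions below build IQN-style TD targets and compute the quantile Huber loss (in the form given by Dabney et al., 2018) in PyTorch. They assume the online quantiles for the taken actions have already been gathered into a `(batch, N)` tensor and that next-state quantiles from the target network arrive as `(batch, action_dim, N')`. The names `quantile_huber_loss` and `iqn_targets` are illustrative, not the package's API.

```python
import torch


def iqn_targets(
    rewards: torch.Tensor,         # (B,)
    dones: torch.Tensor,           # (B,) 1.0 where the episode ended
    next_quantiles: torch.Tensor,  # (B, A, N') target-network quantiles at s'
    next_actions: torch.Tensor,    # (B,) int64 greedy actions at s'
    gamma: float,
) -> torch.Tensor:
    """Return TD targets of shape (B, N'): r + gamma * (1 - done) * Z'(s', a*)."""
    n_prime = next_quantiles.shape[-1]
    # Select the quantile row of the greedy next action for every batch element.
    idx = next_actions.view(-1, 1, 1).expand(-1, 1, n_prime)
    chosen = next_quantiles.gather(1, idx).squeeze(1)  # (B, N')
    return rewards.unsqueeze(1) + gamma * (1.0 - dones.unsqueeze(1)) * chosen


def quantile_huber_loss(
    pred: torch.Tensor,    # (B, N)  online quantiles for the taken action
    target: torch.Tensor,  # (B, N') TD targets (should carry no gradient)
    taus: torch.Tensor,    # (B, N)  fractions that produced `pred`
    kappa: float = 1.0,
) -> torch.Tensor:
    """Sum over online quantiles, mean over target samples, mean over batch."""
    # Pairwise TD errors: td[b, i, j] = target[b, j] - pred[b, i]
    td = target.unsqueeze(1) - pred.unsqueeze(2)  # (B, N, N')
    abs_td = td.abs()
    # Elementwise Huber loss with threshold kappa.
    huber = torch.where(
        abs_td <= kappa, 0.5 * td.pow(2), kappa * (abs_td - 0.5 * kappa)
    )
    # Asymmetric quantile weight |tau_i - 1{td < 0}|.
    weight = (taus.unsqueeze(2) - (td.detach() < 0).float()).abs()
    loss = weight * huber / kappa
    return loss.sum(dim=1).mean(dim=1).mean()
```

In an `IQN.update()` built on these helpers, the targets would be computed under `torch.no_grad()` so that gradients flow only through the online network.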
Task 1: Add failing IQN coverage
Files:
- Create: tests/test_iqn_update.py
- Create: tests/test_iqn_reference_script.py
- Modify: tests/test_dqn_trainer_smoke.py
- Modify: tests/test_cli.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py
Step 1: Write the failing tests
Add tests that expect:
- MLPIQNetwork returns quantile values with shape (batch, action_dim, num_quantiles)
- MLPIQNetwork.act() selects the action with the highest mean over sampled quantiles
- IQN.update() returns the standard DQN-family metrics
- a config with algo: iqn trains through the shared DQN trainer and writes a checkpoint
- the CLI's train --config accepts algo: iqn
- the managed API exports IQN
- the reference script runs successfully
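One way to make the act() expectation deterministic, rather than relying on random quantile draws, is to test the reduction on a hand-built tensor where the mean and the maximum disagree. The `greedy_actions` helper below is hypothetical; in the real test it would stand in for the step `MLPIQNetwork.act()` performs after computing quantiles.

```python
import torch


def greedy_actions(quantiles: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: argmax over actions of the mean over quantiles.

    quantiles: (batch, action_dim, num_quantiles)
    """
    return quantiles.mean(dim=-1).argmax(dim=-1)


def test_act_uses_mean_over_quantiles() -> None:
    # Action 0 has the larger maximum quantile, but action 1 has the larger
    # mean, so a mean-based greedy policy must pick action 1.
    quantiles = torch.tensor([[[0.0, 0.0, 9.0],    # action 0: mean 3.0, max 9.0
                               [4.0, 4.0, 4.0]]])  # action 1: mean 4.0
    assert quantiles.shape == (1, 2, 3)
    assert greedy_actions(quantiles).tolist() == [1]
```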
Step 2: Run the tests to verify they fail
Run: PYTHONPATH=src pytest tests/test_iqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_iqn_reference_script.py -q
Expected: FAIL, because the model, algorithm, and wiring do not exist yet.
Task 2: Implement the IQN model and algorithm
Files:
- Create: src/rl_training/models/mlp_iqn_network.py
- Create: src/rl_training/algorithms/iqn.py
- Modify: src/rl_training/models/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
Step 1: Write the minimal implementation
Implement:
- MLPIQNetwork with cosine quantile embeddings, sampled quantile fractions, and an .act() that acts greedily on the mean over sampled quantiles
- IQN with a target network, sampled target quantiles, the quantile Huber loss, state serialization, and train/eval mode helpers
Step 2: Run focused tests
Run: PYTHONPATH=src pytest tests/test_iqn_update.py -q
Expected: PASS
Task 3: Wire IQN into the shared DQN trainer
Files:
- Modify: src/rl_training/runtime/dqn_trainer.py
Step 1: Update trainer factories
Add:
- a _build_q_network() branch for config.algo == "iqn" that returns MLPIQNetwork
- a _build_algorithm() branch that returns IQN
- trainer/evaluation typing updates so iqn follows the same training and greedy-evaluation path as qr_dqn
Step 2: Run focused tests
Run: PYTHONPATH=src pytest tests/test_dqn_trainer_smoke.py::test_train_iqn_writes_checkpoint_and_metrics -q
Expected: PASS
Task 4: Wire registry + managed API + exports
Files:
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
Step 1: Registry load/evaluate/predict
Add:
- _load_iqn_algorithm()
- _evaluate_iqn()
- _predict_iqn()
- an _ALGORITHM_REGISTRY["iqn"] entry
Step 2: Managed API exports
Add class IQN(ManagedAlgorithm) with algo_name = "iqn" and expose it from the public packages (rl_training.api and the top-level rl_training package).
Step 3: Run focused tests
Run: PYTHONPATH=src pytest tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py -q
Expected: PASS
Task 5: Add config + example
Files:
- Create: configs/iqn/cartpole.yaml
- Create: examples/iqn_cartpole_reference.py
Step 1: Write the minimal implementation
Add configs/iqn/cartpole.yaml with the IQN-specific hyperparameters (num_quantiles, num_tau_samples, and embedding_dim) alongside the usual DQN-family settings. Add examples/iqn_cartpole_reference.py following the layout of the existing reference examples.
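To keep the IQN-specific keys from drifting, the trainer could parse them into a small typed record like the hypothetical `IQNHyperparams` below. The plan names `num_quantiles`, `num_tau_samples`, and `embedding_dim` but does not define them; the docstring records one plausible reading, and the defaults are illustrative rather than tuned for CartPole.

```python
from dataclasses import dataclass, fields
from typing import Any, Mapping


@dataclass(frozen=True)
class IQNHyperparams:
    """Hypothetical typed view of the IQN keys in configs/iqn/cartpole.yaml.

    Assumed meanings:
      num_quantiles   - online fractions tau sampled per update (N)
      num_tau_samples - target fractions tau' sampled per update (N')
      embedding_dim   - number of cosine features in the quantile embedding
    """
    num_quantiles: int = 32
    num_tau_samples: int = 32
    embedding_dim: int = 64

    def __post_init__(self) -> None:
        # Reject zero, negative, non-integer, and boolean values early.
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
                raise ValueError(f"{f.name} must be a positive int, got {value!r}")

    @classmethod
    def from_mapping(cls, raw: Mapping[str, Any]) -> "IQNHyperparams":
        # Ignore the non-IQN keys (env, optimizer, buffer, ...) in the parsed YAML.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})
```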
Step 2: Run focused tests
Run: PYTHONPATH=src pytest tests/test_iqn_reference_script.py -q
Expected: PASS
Task 6: Verify end-to-end
Files:
- None (verification only)
Step 1: Run targeted verification
Run: PYTHONPATH=src pytest tests/test_iqn_update.py tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_iqn_reference_script.py -q
Expected: PASS
Step 2: Run full verification
Run: PYTHONPATH=src pytest -q
Expected: PASS