N-Step DQN Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

Goal: Add n_step_dqn as a first-class DQN-family algorithm using n-step returns for improved credit assignment, with configs, examples, registry wiring, public API exports, and tests.

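Architecture: Reuse the existing train_dqn runtime loop and checkpoint flow. Add an NStepAccumulator that converts 1-step environment interactions into n-step transitions (s_t, a_t, R_{t:t+n}, s_{t+n}, done), compatible with the existing ReplayBuffer.add(). Here R_{t:t+n} is the discounted reward sum r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}, computed with the base gamma. The DQN update rule stays unchanged: the algorithm is constructed with gamma set to gamma ** n_step, so the existing target computation bootstraps from s_{t+n} with the correct discount. Transitions flushed early at episode end span fewer than n steps, but they carry done=True, so (assuming the existing target masks the bootstrap term on done) the effective discount does not affect them.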
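The effective-discount idea can be checked numerically. The sketch below assumes the existing DQN target has the usual form reward + gamma * max_a Q(s', a) * (1 - done); the function names n_step_return and dqn_target are illustrative and not taken from the codebase.

```python
def n_step_return(rewards, gamma):
    """Discounted sum r_0 + gamma * r_1 + ... using the base gamma."""
    total = 0.0
    for k, reward in enumerate(rewards):
        total += (gamma ** k) * reward
    return total


def dqn_target(reward, gamma, next_q_max, done):
    """Usual 1-step DQN target (assumed form of the existing update rule)."""
    return reward + gamma * next_q_max * (1.0 - float(done))


gamma, n_step = 0.99, 3
rewards = [1.0, 1.0, 1.0]  # r_t, r_{t+1}, r_{t+2}
next_q_max = 10.0          # stand-in for max_a Q(s_{t+3}, a)

# Aggregate rewards with the base gamma, then bootstrap with gamma ** n_step.
aggregated = n_step_return(rewards, gamma)
target = dqn_target(aggregated, gamma ** n_step, next_q_max, done=False)

# The explicitly unrolled 3-step target, for comparison.
explicit = 1.0 + gamma * 1.0 + gamma ** 2 * 1.0 + gamma ** 3 * next_q_max
```

Both values agree up to floating-point error, which is why the update rule itself needs no change.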

Tech Stack: Python 3.10, PyTorch, Gymnasium, pytest


Task 1: Add failing n-step coverage

Files:
- Modify: tests/test_dqn_trainer_smoke.py
- Modify: tests/test_cli.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py
- Create: tests/test_n_step_dqn_reference_script.py

Step 1: Write the failing test

Add tests that expect:
- algo: n_step_dqn can be trained via registry and writes a checkpoint
- CLI train --config works with algo: n_step_dqn
- managed API exports NStepDQN
- the reference script runs successfully

Step 2: Run test to verify it fails

Run: pytest tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_n_step_dqn_reference_script.py -q

Expected: FAIL because n_step_dqn wiring does not exist yet.


Task 2: Implement n-step transition accumulation in the DQN trainer

Files:
- Create: src/rl_training/data/n_step.py
- Modify: src/rl_training/runtime/dqn_trainer.py

Step 1: Write minimal implementation

Implement NStepAccumulator(num_envs, n_step, gamma) with:
- per-env queues (vector-env safe)
- add(env_index, obs, action, reward, next_obs, done) -> list[transition_dict]
- flush remaining transitions on done=True
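A minimal sketch of such an accumulator follows. The constructor and add() signatures come from the plan; the transition dict keys (obs, action, reward, next_obs, done) and the private helper _pop_front are assumptions and should be matched to whatever ReplayBuffer.add() actually expects.

```python
from collections import deque


class NStepAccumulator:
    """Turns 1-step interactions into n-step transitions, one queue per env."""

    def __init__(self, num_envs: int, n_step: int, gamma: float) -> None:
        if n_step < 1:
            raise ValueError("n_step must be >= 1")
        self.n_step = n_step
        self.gamma = gamma
        # Each queue holds pending (obs, action, reward) tuples for one env.
        self._queues = [deque() for _ in range(num_envs)]

    def _pop_front(self, queue, next_obs, done) -> dict:
        # Discounted return over everything queued, using the base gamma.
        ret = 0.0
        for k, (_, _, reward) in enumerate(queue):
            ret += (self.gamma ** k) * reward
        obs, action, _ = queue.popleft()
        # Dict keys are assumed; adapt them to the replay buffer's schema.
        return {"obs": obs, "action": action, "reward": ret,
                "next_obs": next_obs, "done": done}

    def add(self, env_index, obs, action, reward, next_obs, done) -> list[dict]:
        queue = self._queues[env_index]
        queue.append((obs, action, float(reward)))
        out = []
        if done:
            # Episode ended: flush every pending start state as a shorter,
            # terminal transition so nothing leaks into the next episode.
            while queue:
                out.append(self._pop_front(queue, next_obs, True))
        elif len(queue) == self.n_step:
            # A full n-step window is available for the oldest entry.
            out.append(self._pop_front(queue, next_obs, False))
        return out


acc = NStepAccumulator(num_envs=1, n_step=3, gamma=0.5)
first = acc.add(0, obs=0, action=0, reward=1.0, next_obs=1, done=False)
second = acc.add(0, obs=1, action=0, reward=1.0, next_obs=2, done=False)
third = acc.add(0, obs=2, action=0, reward=1.0, next_obs=3, done=False)
final = acc.add(0, obs=3, action=0, reward=1.0, next_obs=4, done=True)
```

In this usage the first two calls return nothing, the third emits one full 3-step transition, and the terminal step flushes the three remaining start states. This sketch treats done as a true terminal signal; handling time-limit truncation separately would need an extra flag and is left out here.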

Then update train_dqn() so that when config.algo == "n_step_dqn":
- transitions are fed through NStepAccumulator before adding to replay
- the algorithm is constructed with effective_gamma = gamma ** n_step
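The discount selection inside train_dqn() could be isolated in a small helper, sketched below. The helper name resolve_algorithm_gamma, the dict-style access to algo_kwargs, and the fallback values (0.99 and 1) are assumptions for illustration; the real trainer may read these from its config object differently.

```python
def resolve_algorithm_gamma(algo: str, algo_kwargs: dict) -> float:
    """Return the gamma to hand to the DQN algorithm (hypothetical helper)."""
    # Default values here are assumptions, not the project's actual defaults.
    gamma = float(algo_kwargs.get("gamma", 0.99))
    if algo != "n_step_dqn":
        # Other DQN-family algorithms keep the base discount untouched.
        return gamma
    n_step = int(algo_kwargs.get("n_step", 1))
    if n_step < 1:
        raise ValueError("n_step must be >= 1")
    # The accumulator already discounts rewards with the base gamma, so the
    # bootstrap term must be discounted by gamma ** n_step.
    return gamma ** n_step


plain = resolve_algorithm_gamma("dqn", {"gamma": 0.9})
effective = resolve_algorithm_gamma("n_step_dqn", {"gamma": 0.5, "n_step": 3})
```

Keeping the base gamma for the accumulator and the effective gamma for the algorithm as two separate values avoids discounting the rewards twice.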

Step 2: Run focused tests

Run: pytest tests/test_dqn_trainer_smoke.py::test_train_n_step_dqn_writes_checkpoint_and_metrics -q

Expected: PASS


Task 3: Wire n_step_dqn into registry and public API

Files:
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
- Modify: src/rl_training/__init__.py

Step 1: Wire the algorithm spec

Add _ALGORITHM_REGISTRY["n_step_dqn"] with:
- train_fn=train_dqn
- evaluate_fn=_evaluate_dqn
- predict_fn=_predict_dqn

Step 2: Add managed API

Expose class NStepDQN(ManagedAlgorithm) with algo_name = "n_step_dqn" and export it from both rl_training.api and the top-level rl_training package.

Step 3: Run focused tests

Run: pytest tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py -q

Expected: PASS


Task 4: Add config + example + README mention

Files:
- Create: configs/n_step_dqn/cartpole.yaml
- Create: examples/n_step_dqn_cartpole_reference.py
- Modify: README.md

Step 1: Write minimal implementation

Add:
- CartPole config including algo_kwargs.n_step (e.g. 3) and algo_kwargs.gamma
- a runnable reference script consistent with other examples
- README algorithm list mention

Step 2: Run focused tests

Run: pytest tests/test_n_step_dqn_reference_script.py -q

Expected: PASS


Task 5: Verify end-to-end

Files: - Verify only

Run:
- Targeted: pytest tests/test_dqn_trainer_smoke.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_n_step_dqn_reference_script.py -q
- Full: pytest -q

Expected: PASS