
DRQN V1 Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

Goal: Add a narrow but honest DRQN baseline to rl_training for discrete-action, vector-observation environments, exposed through the standard train/eval/resume/predict package surface.

Architecture: Keep this release deliberately small. Reuse the existing DQN package shape, but implement recurrence explicitly instead of overloading the feed-forward DQN trainer. Add an LSTM-backed Q-network, a recurrent replay buffer that stores fixed-length chunks together with each chunk's initial hidden state, and a dedicated DRQN trainer. Do not attempt Atari/image support, distributed replay, burn-in, or multi-step off-policy corrections in this batch.

Tech Stack: Python 3.10, PyTorch, Gymnasium, pytest, existing rl_training config/checkpoint/API stack.


Task 1: Add failing DRQN coverage

Files:
- Create: tests/test_drqn_update.py
- Create: tests/test_drqn_trainer_smoke.py
- Create: tests/test_drqn_reference_script.py
- Create: tests/test_recurrent_replay_buffer.py
- Modify: tests/test_package_api_exports.py
- Modify: tests/test_public_api.py

Step 1: Write the failing tests
- LSTMQNetwork returns Q-values, actions, and recurrent state for vector observations.
- DRQN.update() returns named metrics on sequence batches.
- RecurrentReplayBuffer stores chunks with masks and initial hidden state.
- train_drqn() writes a checkpoint and evaluation metrics on CartPole-v1.
- The public API exports DRQN.
- The reference script runs as a smoke command.

Step 2: Run the tests to verify they fail
Run: pytest -q tests/test_drqn_update.py tests/test_drqn_trainer_smoke.py tests/test_drqn_reference_script.py tests/test_recurrent_replay_buffer.py tests/test_package_api_exports.py tests/test_public_api.py
Expected: FAIL because the drqn modules and exports do not exist yet.

Task 2: Implement recurrent Q model and replay buffer

Files:
- Create: src/rl_training/models/recurrent/lstm_q_network.py
- Create: src/rl_training/data/recurrent_replay_buffer.py
- Modify: src/rl_training/models/recurrent/__init__.py
- Modify: src/rl_training/models/__init__.py
- Modify: src/rl_training/data/__init__.py

Step 1: Write the minimal implementation
- LSTMQNetwork with:
  - a vector-observation encoder MLP
  - an LSTM core
  - a Q-head
  - initial_state(), reset_state(), act(), and q_values_sequence()
- RecurrentReplayBuffer with:
  - per-env chunk assembly
  - fixed-length sequence storage
  - masks and episode-start flags
  - a stored initial hidden state per chunk
  - sample() returning (time, batch, ...) tensors
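The chunk-assembly bullets for RecurrentReplayBuffer can be illustrated with a framework-free sketch. It assumes plain Python lists stand in for tensors, that the collector passes in the recurrent state it held before each step, and that short chunks at episode end are padded with mask 0.0. The names `Chunk`, `ChunkBuffer`, `add`, and `state_before` are hypothetical and are not the planned rl_training API. The sketch also stores `next_obs` per step so a one-step TD target can bootstrap from the last real step.

```python
import random
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Chunk:
    """One fixed-length sequence from a single environment (hypothetical layout)."""
    init_state: Any  # recurrent state held before the chunk's first step
    obs: List[Any] = field(default_factory=list)
    actions: List[int] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)
    next_obs: List[Any] = field(default_factory=list)
    dones: List[bool] = field(default_factory=list)
    episode_start: List[bool] = field(default_factory=list)
    mask: List[float] = field(default_factory=list)  # 1.0 = real step, 0.0 = padding


class ChunkBuffer:
    """Hypothetical sketch of per-env chunk assembly; not the planned rl_training API."""

    def __init__(self, seq_len: int, capacity: int, num_envs: int, seed: int = 0):
        self.seq_len = seq_len
        self.capacity = capacity
        self.chunks: List[Chunk] = []
        self._pending: List[Optional[Chunk]] = [None] * num_envs  # one open chunk per env
        self._rng = random.Random(seed)

    def add(self, env_id, obs, action, reward, next_obs, done, episode_start, state_before):
        chunk = self._pending[env_id]
        if chunk is None:
            # A new chunk remembers the state the policy held before this step.
            chunk = Chunk(init_state=state_before)
            self._pending[env_id] = chunk
        chunk.obs.append(obs)
        chunk.actions.append(action)
        chunk.rewards.append(reward)
        chunk.next_obs.append(next_obs)
        chunk.dones.append(done)
        chunk.episode_start.append(episode_start)
        chunk.mask.append(1.0)
        # Close the chunk when it is full or the episode ends.
        if len(chunk.obs) == self.seq_len or done:
            self._flush(env_id)

    def _flush(self, env_id: int) -> None:
        chunk = self._pending[env_id]
        pad = self.seq_len - len(chunk.obs)
        # Pad short chunks by repeating the last step; mask 0.0 keeps padding out of the loss.
        chunk.obs += [chunk.obs[-1]] * pad
        chunk.actions += [0] * pad
        chunk.rewards += [0.0] * pad
        chunk.next_obs += [chunk.next_obs[-1]] * pad
        chunk.dones += [True] * pad
        chunk.episode_start += [False] * pad
        chunk.mask += [0.0] * pad
        self.chunks.append(chunk)
        if len(self.chunks) > self.capacity:
            self.chunks.pop(0)  # evict the oldest chunk (FIFO)
        self._pending[env_id] = None

    def sample(self, batch_size: int) -> Dict[str, Any]:
        batch = self._rng.sample(self.chunks, batch_size)
        keys = ("obs", "actions", "rewards", "next_obs", "dones", "episode_start", "mask")
        # Time-major layout out[key][t][b], mirroring a (time, batch, ...) tensor.
        out: Dict[str, Any] = {
            k: [[getattr(c, k)[t] for c in batch] for t in range(self.seq_len)] for k in keys
        }
        out["init_state"] = [c.init_state for c in batch]
        return out
```

When an episode runs longer than `seq_len`, the next chunk starts mid-episode and its `init_state` is the carried-over LSTM state. After `done`, the collector is expected to reset its state, so the following chunk starts from the reset state.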

Step 2: Run the tests to verify they pass
Run: pytest -q tests/test_drqn_update.py::test_lstm_q_network_vector_obs_returns_q_values_actions_and_state tests/test_recurrent_replay_buffer.py
Expected: PASS.

Task 3: Implement DRQN algorithm and trainer

Files:
- Create: src/rl_training/algorithms/drqn.py
- Create: src/rl_training/runtime/drqn_trainer.py

Step 1: Write the minimal implementation
- DRQN algorithm with:
  - a target network
  - a masked sequence TD loss
  - target-network sync at a fixed interval
- train_drqn() with:
  - vector discrete env support only
  - recurrent online collection with epsilon-greedy actions
  - recurrent replay chunk storage
  - periodic evaluation and checkpointing
- _evaluate_drqn_policy() for greedy recurrent rollout.
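After both networks are unrolled from each chunk's stored initial state, the masked sequence TD loss reduces to ordinary one-step DQN targets with the padded steps excluded. The sketch below is framework-free and assumes the Q-values have already been computed as nested lists indexed `[T][B][A]`. `q_next_target[t][b]` holds the target network's Q-values for the step's next observation. The helper names `masked_sequence_td_loss` and `should_sync_target` are hypothetical. The plan does not say whether the loss is squared error or Huber, so squared error is used here only as an illustration.

```python
from typing import Sequence


def masked_sequence_td_loss(
    q_online: Sequence[Sequence[Sequence[float]]],       # [T][B][A]: online Q(o_t, .)
    q_next_target: Sequence[Sequence[Sequence[float]]],  # [T][B][A]: target Q(o_{t+1}, .)
    actions: Sequence[Sequence[int]],                    # [T][B]
    rewards: Sequence[Sequence[float]],                  # [T][B]
    terminated: Sequence[Sequence[bool]],                # [T][B]: true terminal, not time-limit truncation
    mask: Sequence[Sequence[float]],                     # [T][B]: 1.0 real step, 0.0 padding
    gamma: float,
) -> float:
    """Mean squared one-step TD error over unmasked steps (hypothetical helper)."""
    total, weight = 0.0, 0.0
    for t in range(len(actions)):
        for b in range(len(actions[t])):
            m = mask[t][b]
            if m == 0.0:
                continue  # padded steps contribute nothing
            q_sa = q_online[t][b][actions[t][b]]
            # Standard DQN target: no bootstrapping past a terminal state.
            bootstrap = 0.0 if terminated[t][b] else max(q_next_target[t][b])
            target = rewards[t][b] + gamma * bootstrap
            total += m * (q_sa - target) ** 2
            weight += m
    return total / weight if weight > 0 else 0.0


def should_sync_target(update_step: int, interval: int) -> bool:
    """True when the online weights should be copied into the target network."""
    return update_step > 0 and update_step % interval == 0
```

In a PyTorch version, the same computation would typically select `q_sa` with `torch.gather` along the action dimension and divide the masked sum by `mask.sum()`.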

Step 2: Run the tests to verify they pass
Run: pytest -q tests/test_drqn_update.py tests/test_drqn_trainer_smoke.py
Expected: PASS.

Task 4: Wire package surfaces and examples

Files:
- Create: examples/drqn_cartpole_reference.py
- Create: configs/drqn/cartpole.yaml
- Create: src/rl_training/assets/configs/drqn/cartpole.yaml
- Modify: src/rl_training/algorithms/__init__.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/__init__.py
- Modify: src/rl_training/experiment/registry.py
- Modify: README.md
- Modify: docs/plans/2026-03-12-rl-yearly-sourcebook-design.md

Step 1: Write the minimal implementation
- Add registry load/eval/predict functions.
- Add the managed API class DRQN.
- Add root, api, and algorithms exports.
- Add the config and the packaged asset config.
- Add the reference script.
- Update README.md and the yearly sourcebook to mark DRQN as implemented (narrow v1).

Step 2: Run the tests to verify they pass
Run: pytest -q tests/test_drqn_reference_script.py tests/test_package_api_exports.py tests/test_public_api.py tests/test_cli.py
Expected: PASS.

Task 5: Regression verification

Files: - Modify only if verification reveals regressions.

Step 1: Run focused regression coverage
Run: pytest -q tests/test_drqn_update.py tests/test_drqn_trainer_smoke.py tests/test_drqn_reference_script.py tests/test_recurrent_replay_buffer.py tests/test_package_api_exports.py tests/test_public_api.py tests/test_cli.py tests/test_dqn_update.py tests/test_dqn_trainer_smoke.py tests/test_recurrent_models.py tests/test_recurrent_ppo_update.py
Expected: PASS.

Step 2: Run the full suite
Run: pytest -q
Expected: PASS.