
Atari + Recurrent PPO Phase 1 Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

Goal: Turn rl_training into a mainstream-style RL package for classic control and Atari by adding pixel-observation support, Atari presets, a RecurrentPPO contrib algorithm, a first zoo/ layer, and the packaging/docs polish needed for real adoption.

Architecture: Keep the existing rl_training core stable while extending it in three deliberate directions. First, teach the current env and trainer stack to support Atari wrappers and CNN encoders so existing DQN/PPO flows can train on image observations. Second, isolate recurrent PPO into rl_training.contrib so sequence state and hidden-state checkpoint semantics do not leak into every core policy path. Third, add a zoo/ layer for benchmark presets, scripts, and docs, plus installable CLI metadata, so the new capability is visible and reproducible instead of being buried in ad hoc examples.

Tech Stack: Python 3.10, PyTorch, Gymnasium, ALE-compatible Atari environments, pytest, setuptools


Task 1: Add failing Atari environment and CNN coverage

Files:
- Create: tests/test_atari_envs.py
- Create: tests/test_nature_cnn.py
- Modify: tests/test_envs.py
- Modify: tests/test_module_contracts.py

Step 1: Write the failing test

Add tests that define:
- make_env() applies Atari preprocessing only when env_kwargs or config tags request the Atari pipeline
- Atari env factories can produce stacked image observations with stable shapes
- NatureCNN accepts (batch, channels, height, width) observations and returns a fixed-size feature tensor
- non-image classic-control flows remain unchanged
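The first and last bullets together describe an opt-in contract: classic-control configs must never pick up Atari wrappers by accident. A minimal sketch of the predicate such tests could pin down is shown here. The function name `wants_atari_pipeline`, the `env_kwargs["atari"]` flag, and the `"atari"` tag are all hypothetical; the plan does not fix how a config requests the pipeline.

```python
from typing import Any, Mapping


def wants_atari_pipeline(config: Mapping[str, Any]) -> bool:
    """Return True only when the config explicitly opts in to Atari preprocessing.

    Hypothetical helper: the `env_kwargs["atari"]` flag and the "atari" tag
    are assumed names, not part of a confirmed rl_training config schema.
    """
    env_kwargs = config.get("env_kwargs") or {}
    tags = config.get("tags") or ()
    return bool(env_kwargs.get("atari")) or "atari" in tags


# A plain classic-control config stays on the existing path.
print(wants_atari_pipeline({"env_id": "CartPole-v1"}))  # False
# Either opt-in route enables the wrapper stack.
print(wants_atari_pipeline({"env_kwargs": {"atari": True}}))  # True
print(wants_atari_pipeline({"tags": ["atari"]}))  # True
```

Keeping the decision in one pure function makes the "non-image flows remain unchanged" requirement testable without constructing any environment.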

Step 2: Run test to verify it fails

Run: PYTHONPATH=src pytest tests/test_atari_envs.py tests/test_nature_cnn.py tests/test_envs.py tests/test_module_contracts.py -q

Expected: FAIL because Atari preprocessing and CNN feature extraction do not exist yet.

Step 3: Write minimal implementation

Create the smallest Atari wrapper helpers and CNN model needed to satisfy the tests without changing trainer logic yet.

Step 4: Run test to verify it passes

Run: PYTHONPATH=src pytest tests/test_atari_envs.py tests/test_nature_cnn.py tests/test_envs.py tests/test_module_contracts.py -q

Expected: PASS

Task 2: Implement Atari wrappers and image-ready env factories

Files:
- Create: src/rl_training/envs/atari.py
- Modify: src/rl_training/envs/factory.py
- Modify: src/rl_training/envs/__init__.py
- Create: src/rl_training/models/cnn/__init__.py
- Create: src/rl_training/models/cnn/nature.py
- Modify: src/rl_training/models/__init__.py

Step 1: Write minimal implementation

Implement:
- a helper that detects Atari-style configs and builds the wrapper stack
- support for grayscale, resize, frame stack, and channel-first conversion
- NatureCNN with configurable output feature size
- exports so CNN encoders are available to trainer/model code
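When NatureCNN is sized, the one number that must be right is the flattened feature count feeding the final linear layer. The sketch below works it out by hand. It assumes the commonly used Nature DQN convolution stack (8x8 stride 4, 4x4 stride 2, 3x3 stride 1, no padding) and 84x84 inputs; the plan itself does not specify these hyperparameters, so treat them as illustrative.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a conv layer (floor division, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# (out_channels, kernel, stride) for the three conv layers.
# Assumed values; the plan does not fix these hyperparameters.
NATURE_LAYERS = ((32, 8, 4), (64, 4, 2), (64, 3, 1))


def nature_flat_features(height: int, width: int) -> int:
    """Number of features entering the final linear layer."""
    channels = 0
    for channels, kernel, stride in NATURE_LAYERS:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return channels * height * width


# 84 -> 20 -> 9 -> 7 per spatial dimension, so 64 * 7 * 7 features.
print(nature_flat_features(84, 84))  # 3136
```

The input channel count (the frame-stack depth) does not affect this number, which is why the frame stack can stay configurable without changing the size of the linear layer.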

Step 2: Run focused tests

Run: PYTHONPATH=src pytest tests/test_atari_envs.py tests/test_nature_cnn.py -q

Expected: PASS

Task 3: Add image-observation support to DQN and PPO training paths

Files:
- Create: tests/test_atari_dqn_trainer_smoke.py
- Create: tests/test_atari_ppo_trainer_smoke.py
- Modify: src/rl_training/runtime/dqn_trainer.py
- Modify: src/rl_training/runtime/ppo_trainer.py
- Modify: src/rl_training/experiment/registry.py
- Modify: tests/test_checkpoint_workflows.py
- Modify: tests/test_experiment_manager.py

Step 1: Write the failing tests

Add smoke tests that expect:
- train_dqn() can train on a tiny Atari-configured env and emit a checkpoint
- train_ppo() can train on a tiny Atari-configured env and emit a checkpoint
- evaluate_checkpoint() works for Atari DQN and PPO checkpoints

Step 2: Run tests to verify they fail

Run: PYTHONPATH=src pytest tests/test_atari_dqn_trainer_smoke.py tests/test_atari_ppo_trainer_smoke.py tests/test_checkpoint_workflows.py tests/test_experiment_manager.py -q

Expected: FAIL because the current trainers only support flat 1D observations and MLP-only loading.

Step 3: Write minimal implementation

Update trainer and registry logic so:
- image observations can flow through CNN-based models
- vector observations still use the current MLP code paths
- checkpoint load/evaluate/predict infer the correct encoder family from config
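The third bullet is the one most likely to cause silent breakage, so it helps to make the inference rule explicit. The sketch below shows one possible rule. The function names, the `encoder` key, and the `obs_shape` key are hypothetical and stand in for whatever the real config schema provides.

```python
from typing import Sequence


def infer_encoder_family(obs_shape: Sequence[int]) -> str:
    """Map an observation shape to an encoder family name.

    Hypothetical rule: 1-D observations use the existing MLP path and 3-D
    (channels, height, width) observations use the CNN path.
    """
    if len(obs_shape) == 1:
        return "mlp"
    if len(obs_shape) == 3:
        return "cnn"
    raise ValueError(f"unsupported observation shape: {tuple(obs_shape)}")


def encoder_family_from_config(config: dict) -> str:
    """Prefer an explicit `encoder` key (assumed name), else infer from shape."""
    explicit = config.get("encoder")
    if explicit is not None:
        return explicit
    return infer_encoder_family(config["obs_shape"])


print(encoder_family_from_config({"obs_shape": (4,)}))  # mlp
print(encoder_family_from_config({"obs_shape": (4, 84, 84)}))  # cnn
```

Recording the family in the saved config, rather than guessing it from weight names at load time, would let load, evaluate, and predict all rebuild the same model from a single source of truth.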

Step 4: Run tests to verify they pass

Run: PYTHONPATH=src pytest tests/test_atari_dqn_trainer_smoke.py tests/test_atari_ppo_trainer_smoke.py tests/test_checkpoint_workflows.py tests/test_experiment_manager.py -q

Expected: PASS

Task 4: Add Atari presets, reference scripts, and CLI coverage

Files:
- Create: configs/dqn/breakout_atari.yaml
- Create: configs/ppo/breakout_atari.yaml
- Create: examples/dqn_breakout_atari_reference.py
- Create: examples/ppo_breakout_atari_reference.py
- Create: tests/test_atari_reference_scripts.py
- Modify: tests/test_cli.py
- Modify: README.md

Step 1: Write the failing tests

Add coverage for:
- CLI train --config with Atari DQN and PPO configs
- reference scripts that run a tiny smoke job with the Atari pipeline enabled
- README usage snippets that mention the Atari configs

Step 2: Run tests to verify they fail

Run: PYTHONPATH=src pytest tests/test_cli.py tests/test_atari_reference_scripts.py -q

Expected: FAIL because Atari configs, examples, and docs do not exist yet.

Step 3: Write minimal implementation

Add compact Atari configs and examples that use tiny smoke-test settings and document the intended commands.

Step 4: Run tests to verify they pass

Run: PYTHONPATH=src pytest tests/test_cli.py tests/test_atari_reference_scripts.py -q

Expected: PASS

Task 5: Add failing recurrent rollout and recurrent policy coverage

Files:
- Create: tests/test_recurrent_rollout_buffer.py
- Create: tests/test_recurrent_ppo_update.py
- Create: tests/test_recurrent_models.py
- Modify: tests/test_module_contracts.py

Step 1: Write the failing tests

Add tests that define:
- a recurrent rollout buffer can store hidden states, episode starts, and sequence masks
- an LSTM actor-critic returns actions, values, logprobs, and next hidden state
- recurrent_ppo_loss(...) or equivalent update logic returns named metrics
- invalid sequence-length or hidden-size settings fail fast
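The buffer bullet bundles three ideas (episode starts, fixed-length sequences, and masks) that are easier to test once their interaction is written down. The following sketch shows one plausible chunking scheme for a single environment's rollout. The function name and the padding convention are assumptions and are not taken from the existing codebase.

```python
from typing import List, Sequence, Tuple


def split_into_sequences(
    episode_starts: Sequence[bool], seq_len: int
) -> List[Tuple[List[int], List[bool]]]:
    """Chunk a rollout into padded sequences that never cross an episode boundary.

    Returns (indices, mask) pairs. Padded slots repeat the last real index and
    carry mask False so they can be excluded from the loss.
    """
    if seq_len <= 0:
        # Fail fast on invalid settings, as the last bullet requires.
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    chunks: List[List[int]] = []
    current: List[int] = []
    for t, start in enumerate(episode_starts):
        # A new episode or a full chunk closes the current sequence.
        if current and (start or len(current) == seq_len):
            chunks.append(current)
            current = []
        current.append(t)
    if current:
        chunks.append(current)

    out: List[Tuple[List[int], List[bool]]] = []
    for chunk in chunks:
        pad = seq_len - len(chunk)
        indices = chunk + [chunk[-1]] * pad
        mask = [True] * len(chunk) + [False] * pad
        out.append((indices, mask))
    return out


for indices, mask in split_into_sequences([True, False, False, True, False], seq_len=2):
    print(indices, mask)
# [0, 1] [True, True]
# [2, 2] [True, False]
# [3, 4] [True, True]
```

Under this scheme the hidden state stored at the first index of each chunk is the one to feed the LSTM when that chunk is replayed, and masked slots contribute nothing to the loss.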

Step 2: Run tests to verify they fail

Run: PYTHONPATH=src pytest tests/test_recurrent_rollout_buffer.py tests/test_recurrent_ppo_update.py tests/test_recurrent_models.py tests/test_module_contracts.py -q

Expected: FAIL because recurrent buffer and recurrent policy components do not exist yet.

Step 3: Write minimal implementation

Create the buffer and model primitives with the smallest sequence semantics required by the tests.

Step 4: Run tests to verify they pass

Run: PYTHONPATH=src pytest tests/test_recurrent_rollout_buffer.py tests/test_recurrent_ppo_update.py tests/test_recurrent_models.py tests/test_module_contracts.py -q

Expected: PASS

Task 6: Implement RecurrentPPO under rl_training.contrib

Files:
- Create: src/rl_training/data/recurrent_rollout_buffer.py
- Create: src/rl_training/models/recurrent/__init__.py
- Create: src/rl_training/models/recurrent/lstm_actor_critic.py
- Create: src/rl_training/contrib/__init__.py
- Create: src/rl_training/contrib/api.py
- Create: src/rl_training/contrib/recurrent_ppo.py
- Modify: src/rl_training/models/__init__.py

Step 1: Write minimal implementation

Implement:
- RecurrentRolloutBuffer
- LSTMActorCritic
- a managed RecurrentPPO API class in rl_training.contrib
- enough serialization support for hidden-state-aware training checkpoints

Step 2: Run focused tests

Run: PYTHONPATH=src pytest tests/test_recurrent_rollout_buffer.py tests/test_recurrent_models.py tests/test_recurrent_ppo_update.py -q

Expected: PASS

Task 7: Add the recurrent PPO trainer, registry wiring, and checkpoint flows

Files:
- Create: src/rl_training/runtime/recurrent_ppo_trainer.py
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/cli.py
- Modify: src/rl_training/runtime/workflows.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py
- Modify: tests/test_checkpoint_workflows.py
- Modify: tests/test_experiment_manager.py
- Create: tests/test_recurrent_ppo_trainer_smoke.py
- Create: tests/test_recurrent_ppo_reference_script.py

Step 1: Write the failing tests

Add end-to-end coverage for:
- algo: recurrent_ppo train / eval / resume
- rl_training.contrib.RecurrentPPO(...).learn()
- checkpoint load, prediction, and evaluation
- example-script smoke execution

Step 2: Run tests to verify they fail

Run: PYTHONPATH=src pytest tests/test_recurrent_ppo_trainer_smoke.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_checkpoint_workflows.py tests/test_experiment_manager.py tests/test_recurrent_ppo_reference_script.py -q

Expected: FAIL because recurrent PPO is not yet wired into trainer, registry, and public surfaces.

Step 3: Write minimal implementation

Implement:
- train_recurrent_ppo()
- registry loading, evaluation, and prediction helpers
- contrib exports and example support
- CLI handling that keeps recurrent_ppo explicit rather than silently overloading vanilla PPO
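The last bullet amounts to dispatching on the `algo` key alone, never on the presence of recurrent-looking options. A minimal sketch of that dispatch is given below. The trainer bodies are stand-ins, and `TRAINERS`, `dispatch_train`, and the `lstm_hidden_size` key are hypothetical names chosen for illustration.

```python
from typing import Callable, Dict


def train_ppo(config: dict) -> str:
    return "ppo"  # stand-in for the real trainer


def train_recurrent_ppo(config: dict) -> str:
    return "recurrent_ppo"  # stand-in for the real trainer


# Each algorithm has its own key; nothing upgrades "ppo" behind the user's back.
TRAINERS: Dict[str, Callable[[dict], str]] = {
    "ppo": train_ppo,
    "recurrent_ppo": train_recurrent_ppo,
}


def dispatch_train(config: dict) -> str:
    algo = config.get("algo")
    if algo not in TRAINERS:
        raise ValueError(f"unknown algo {algo!r}; expected one of {sorted(TRAINERS)}")
    return TRAINERS[algo](config)


# A stray recurrent option on a vanilla PPO config does not change the trainer.
print(dispatch_train({"algo": "ppo", "lstm_hidden_size": 64}))  # ppo
print(dispatch_train({"algo": "recurrent_ppo"}))  # recurrent_ppo
```

A real implementation might go further and reject recurrent-only keys on a vanilla PPO config, which would surface misconfiguration instead of ignoring it.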

Step 4: Run tests to verify they pass

Run: PYTHONPATH=src pytest tests/test_recurrent_ppo_trainer_smoke.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_checkpoint_workflows.py tests/test_experiment_manager.py tests/test_recurrent_ppo_reference_script.py -q

Expected: PASS

Task 8: Add zoo/ presets, benchmark manifests, and docs

Files:
- Create: zoo/README.md
- Create: zoo/atari/benchmark.yaml
- Create: zoo/atari/dqn_breakout.yaml
- Create: zoo/atari/ppo_breakout.yaml
- Create: zoo/atari/recurrent_ppo_breakout.yaml
- Create: scripts/benchmark_zoo.py
- Create: tests/test_zoo_presets.py
- Modify: README.md

Step 1: Write the failing tests

Add tests that assert:
- zoo preset files are present and parse cleanly
- benchmark manifests point to real config paths
- README references the zoo layout and run commands
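The manifest check in the second bullet can be reduced to a small path-resolution helper. The sketch below assumes the manifest has already been parsed into a list of paths relative to the repository root; the manifest schema itself is not defined in this plan, and `missing_config_paths` is a hypothetical name.

```python
from pathlib import Path
from typing import Iterable, List


def missing_config_paths(root: Path, entries: Iterable[str]) -> List[str]:
    """Return manifest entries that do not resolve to a file under `root`."""
    return [entry for entry in entries if not (root / entry).is_file()]
```

A test in tests/test_zoo_presets.py could then assert that the returned list is empty, and the failure message would name exactly which presets are missing.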

Step 2: Run tests to verify they fail

Run: PYTHONPATH=src pytest tests/test_zoo_presets.py -q

Expected: FAIL because the zoo layer does not exist yet.

Step 3: Write minimal implementation

Create a minimal but real zoo layer with named presets, a benchmark manifest, and one script that can enumerate or launch preset runs.

Step 4: Run tests to verify they pass

Run: PYTHONPATH=src pytest tests/test_zoo_presets.py -q

Expected: PASS

Task 9: Add installable CLI metadata and package docs polish

Files:
- Modify: pyproject.toml
- Modify: README.md
- Create: tests/test_package_smoke.py

Step 1: Write the failing tests

Add or extend package smoke coverage so it asserts:
- a project.scripts entry exposes the CLI
- README explains train / eval / resume for both core and contrib flows
- package metadata still imports cleanly

Step 2: Run tests to verify they fail

Run: PYTHONPATH=src pytest tests/test_package_smoke.py tests/test_cli.py -q

Expected: FAIL because installable CLI entrypoints and updated docs are missing.

Step 3: Write minimal implementation

Add a console script such as rl-training and update the README so new users can discover the stable core path, the contrib path, and the zoo presets.

Step 4: Run tests to verify they pass

Run: PYTHONPATH=src pytest tests/test_package_smoke.py tests/test_cli.py -q

Expected: PASS

Task 10: Final verification

Files: - Verify only

Step 1: Run targeted Phase 1 verification

Run: PYTHONPATH=src pytest tests/test_atari_envs.py tests/test_nature_cnn.py tests/test_atari_dqn_trainer_smoke.py tests/test_atari_ppo_trainer_smoke.py tests/test_atari_reference_scripts.py tests/test_recurrent_rollout_buffer.py tests/test_recurrent_models.py tests/test_recurrent_ppo_update.py tests/test_recurrent_ppo_trainer_smoke.py tests/test_recurrent_ppo_reference_script.py tests/test_zoo_presets.py tests/test_cli.py tests/test_public_api.py tests/test_package_api_exports.py tests/test_checkpoint_workflows.py tests/test_experiment_manager.py tests/test_package_smoke.py -q

Expected: PASS

Step 2: Run full verification

Run: PYTHONPATH=src pytest -q

Expected: PASS

Step 3: Commit in logical slices

Use small commits that follow the task groupings above, for example:

git add tests/test_atari_envs.py tests/test_nature_cnn.py src/rl_training/envs/atari.py src/rl_training/envs/factory.py src/rl_training/models/cnn src/rl_training/models/__init__.py
git commit -m "feat: add atari env pipeline and nature cnn"

git add tests/test_atari_dqn_trainer_smoke.py tests/test_atari_ppo_trainer_smoke.py src/rl_training/runtime/dqn_trainer.py src/rl_training/runtime/ppo_trainer.py src/rl_training/experiment/registry.py
git commit -m "feat: enable atari training for dqn and ppo"

git add tests/test_recurrent_rollout_buffer.py tests/test_recurrent_models.py tests/test_recurrent_ppo_update.py src/rl_training/data/recurrent_rollout_buffer.py src/rl_training/models/recurrent src/rl_training/contrib
git commit -m "feat: add recurrent ppo primitives"

git add tests/test_recurrent_ppo_trainer_smoke.py tests/test_recurrent_ppo_reference_script.py src/rl_training/runtime/recurrent_ppo_trainer.py src/rl_training/experiment/registry.py src/rl_training/cli.py src/rl_training/runtime/workflows.py
git commit -m "feat: wire recurrent ppo into contrib workflows"

git add zoo scripts/benchmark_zoo.py tests/test_zoo_presets.py README.md pyproject.toml
git commit -m "feat: add zoo presets and package entrypoints"