C51-DQN Implementation Plan

Goal: Add c51_dqn (Categorical / Distributional DQN) as a first-class discrete-action algorithm with:

- fixed atom support [v_min, v_max] with num_atoms
- categorical distribution projection for Bellman updates
- cross-entropy loss against the projected target distribution
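The projection and loss named in the goal can be sketched in plain NumPy as below. This is a minimal illustration under stated assumptions, not the planned C51DQN code: the function names project_distribution and cross_entropy_loss are hypothetical, the inputs are assumed to be already sliced to the chosen action (shape (batch, num_atoms)), and the real implementation would presumably use the project's tensor library instead of NumPy.

```python
import numpy as np


def project_distribution(next_probs, rewards, dones, gamma, v_min, v_max):
    """Project the Bellman-shifted distribution back onto the fixed support.

    Hypothetical helper for illustration only.
    next_probs: (batch, num_atoms) probabilities of the greedy next action.
    rewards, dones: (batch,) arrays; dones is 1.0 where the episode ended.
    """
    batch, num_atoms = next_probs.shape
    support = np.linspace(v_min, v_max, num_atoms)
    delta_z = (v_max - v_min) / (num_atoms - 1)

    # Shift every atom by the Bellman operator, then clip to [v_min, v_max].
    tz = rewards[:, None] + gamma * (1.0 - dones[:, None]) * support[None, :]
    tz = np.clip(tz, v_min, v_max)

    # Fractional grid index of each shifted atom; the extra clip guards
    # against floating-point overshoot at the upper edge.
    b = np.clip((tz - v_min) / delta_z, 0, num_atoms - 1)
    lower = np.floor(b).astype(np.int64)
    upper = np.ceil(b).astype(np.int64)

    # Split each atom's mass between its two neighbours. When b is an
    # integer, lower == upper and all of the mass goes to that atom.
    w_upper = b - lower
    w_lower = 1.0 - w_upper

    target = np.zeros_like(next_probs)
    rows = np.repeat(np.arange(batch)[:, None], num_atoms, axis=1)
    np.add.at(target, (rows, lower), next_probs * w_lower)
    np.add.at(target, (rows, upper), next_probs * w_upper)
    return target


def cross_entropy_loss(pred_logits, target_probs):
    """Mean cross-entropy between projected targets and predicted logits."""
    shifted = pred_logits - pred_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(target_probs * log_probs).sum(axis=1).mean())
```

Because the projection only redistributes mass between neighbouring atoms, each target row still sums to 1, which makes a convenient invariant for tests/test_c51_dqn_update.py to assert.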

Architecture: Keep train_dqn as the trainer entrypoint to maximize reuse (same replay buffer + epsilon-greedy loop). Add a dedicated distributional Q-network (MLPC51QNetwork) and algorithm (C51DQN). Wire checkpoint load/eval/predict through the registry so CLI and managed API work unchanged.


Task 1: Add failing coverage

Files:

- Create: tests/test_c51_dqn_update.py
- Modify: tests/test_dqn_trainer_smoke.py
- Modify: tests/test_cli.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py
- Create: tests/test_c51_dqn_reference_script.py

Expected: the new and modified tests FAIL until the model, algorithm, and registry wiring exist.


Task 2: Implement model + algorithm

Files:

- Create: src/rl_training/models/mlp_c51_q_network.py
- Create: src/rl_training/algorithms/c51_dqn.py

Notes:

- Network outputs logits with shape (batch, action_dim, num_atoms).
- Action selection uses expected value E[Z] under the categorical distribution.
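To make the two notes concrete, here is a small NumPy sketch of greedy action selection from logits of shape (batch, action_dim, num_atoms). The function name greedy_actions is hypothetical and NumPy is an assumption made for illustration; the network and algorithm in the plan may expose this differently.

```python
import numpy as np


def greedy_actions(logits, v_min, v_max):
    """Pick the action with the highest expected return E[Z].

    Hypothetical helper for illustration only.
    logits: (batch, action_dim, num_atoms) raw network outputs.
    Returns (actions, q_values) with shapes (batch,) and (batch, action_dim).
    """
    num_atoms = logits.shape[-1]
    support = np.linspace(v_min, v_max, num_atoms)

    # Numerically stable softmax over the atom axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)

    # Expected value per action: sum over atoms of p_i * z_i.
    q_values = (probs * support).sum(axis=-1)
    return q_values.argmax(axis=-1), q_values
```

Reducing the distribution to its mean only at action-selection time means the existing epsilon-greedy loop in train_dqn can stay unchanged, while the loss still operates on the full distribution.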


Task 3: Wire trainer + registry + API

Files:

- Modify: src/rl_training/runtime/dqn_trainer.py
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
- Modify: src/rl_training/__init__.py


Task 4: Add runnable example + config

Files:

- Create: examples/c51_dqn_cartpole_reference.py
- Create: configs/c51_dqn/cartpole.yaml


Verification

Run: pytest -q