C51-DQN Implementation Plan
Goal: Add c51_dqn (Categorical / Distributional DQN) as a first-class discrete-action algorithm with:
- a fixed support of num_atoms atoms spanning [v_min, v_max]
- categorical distribution projection for Bellman updates
- cross-entropy loss against the projected target distribution
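The projection and loss named in the goal can be sketched in NumPy as follows. This is a minimal illustration of the standard C51 categorical projection, not the planned C51DQN code: the function names project_distribution and cross_entropy_loss, their signatures, and the use of NumPy rather than the project's tensor library are all assumptions made for the example.

```python
import numpy as np


def project_distribution(next_probs, rewards, dones, gamma, v_min, v_max):
    """Project the Bellman-shifted distribution back onto the fixed support.

    Hypothetical helper for illustration only.
    next_probs: (batch, num_atoms) probabilities of the selected next action.
    rewards, dones: (batch,) float arrays; dones is 1.0 on terminal steps.
    Returns (batch, num_atoms) target probabilities.
    """
    batch, num_atoms = next_probs.shape
    support = np.linspace(v_min, v_max, num_atoms)
    delta_z = (v_max - v_min) / (num_atoms - 1)

    # Shift every atom by the Bellman backup, then clip to the support range.
    tz = rewards[:, None] + gamma * (1.0 - dones[:, None]) * support[None, :]
    tz = np.clip(tz, v_min, v_max)

    # Fractional position of each shifted atom on the fixed grid.
    b = np.clip((tz - v_min) / delta_z, 0, num_atoms - 1)
    lower = np.floor(b).astype(np.int64)
    upper = np.ceil(b).astype(np.int64)

    # Split each atom's mass between its two neighbours. When b is an
    # integer, lower == upper and all of the mass goes to that atom.
    w_upper = b - lower
    w_lower = 1.0 - w_upper

    target = np.zeros_like(next_probs, dtype=np.float64)
    rows = np.repeat(np.arange(batch), num_atoms)
    np.add.at(target, (rows, lower.ravel()), (next_probs * w_lower).ravel())
    np.add.at(target, (rows, upper.ravel()), (next_probs * w_upper).ravel())
    return target


def cross_entropy_loss(logits, target_probs):
    """Mean cross-entropy between target_probs and softmax(logits).

    logits, target_probs: (batch, num_atoms) for the action that was taken.
    """
    z = logits - logits.max(axis=-1, keepdims=True)  # stabilise the softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return float(-(target_probs * log_probs).sum(axis=-1).mean())
```

With support [0, 1, 2], all next-state mass on the atom at 1, a reward of 0.5 and gamma = 1, the shifted atom lands at 1.5, so its mass is split evenly between the atoms at 1 and 2.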
Architecture: Keep train_dqn as the trainer entrypoint to maximize reuse (same replay buffer and epsilon-greedy loop). Add a dedicated distributional Q-network (MLPC51QNetwork) and algorithm (C51DQN). Wire checkpoint load/eval/predict through the registry so that the CLI and managed API work unchanged.
Task 1: Add failing coverage
Files:
- Create: tests/test_c51_dqn_update.py
- Modify: tests/test_dqn_trainer_smoke.py
- Modify: tests/test_cli.py
- Modify: tests/test_public_api.py
- Modify: tests/test_package_api_exports.py
- Create: tests/test_c51_dqn_reference_script.py
Expected: FAIL until the model, algorithm, and registry wiring exist.
Task 2: Implement model + algorithm
Files:
- Create: src/rl_training/models/mlp_c51_q_network.py
- Create: src/rl_training/algorithms/c51_dqn.py
Notes:
- Network outputs logits with shape (batch, action_dim, num_atoms).
- Action selection uses expected value E[Z] under the categorical distribution.
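To make the shape contract and the action-selection rule concrete, here is a small NumPy sketch. The helper names expected_q_values and greedy_actions are hypothetical and do not come from the plan; the sketch only assumes logits shaped (batch, action_dim, num_atoms) and a linearly spaced support over [v_min, v_max].

```python
import numpy as np


def expected_q_values(logits, v_min, v_max):
    """Collapse per-action atom logits to scalar Q-values.

    logits: (batch, action_dim, num_atoms)
    Returns (batch, action_dim), where each entry is E[Z] for that action.
    """
    num_atoms = logits.shape[-1]
    support = np.linspace(v_min, v_max, num_atoms)

    # Softmax over the atom axis, with max-subtraction for stability.
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z)
    probs = probs / probs.sum(axis=-1, keepdims=True)

    # Expected value: sum over atoms of p_i * z_i.
    return (probs * support).sum(axis=-1)


def greedy_actions(logits, v_min, v_max):
    """Pick the action with the highest expected value per batch row."""
    return expected_q_values(logits, v_min, v_max).argmax(axis=-1)
```

Because the greedy rule reduces to an argmax over scalar Q-values, an existing epsilon-greedy loop can consume the output of greedy_actions without knowing that the network is distributional.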
Task 3: Wire trainer + registry + API
Files:
- Modify: src/rl_training/runtime/dqn_trainer.py
- Modify: src/rl_training/experiment/registry.py
- Modify: src/rl_training/api/algorithms.py
- Modify: src/rl_training/api/__init__.py
- Modify: src/rl_training/algorithms/__init__.py
- Modify: src/rl_training/__init__.py
Task 4: Add runnable example + config
Files:
- Create: examples/c51_dqn_cartpole_reference.py
- Create: configs/c51_dqn/cartpole.yaml
Verification
Run: pytest -q