训练流程¶
本页详细介绍 AxiomRL 的完整训练管线——从 YAML 配置文件到训练产物(检查点、日志、评估结果)的全过程。
总览¶
AxiomRL 的训练流程遵循以下管线:
graph LR
A[YAML 配置] --> B[TrainConfig]
B --> C[Algorithm]
C --> D[训练循环]
D --> E[Collector]
E --> F[Buffer]
F --> D
D --> G[Evaluator]
D --> H[Checkpoint]
D --> I[TensorBoard] 整个流程可概括为:配置 → 初始化 → 循环(收集 → 缓存 → 更新 → 评估 → 记录)。每个阶段的行为均由 TrainConfig 中的字段控制。
配置加载¶
训练从加载配置开始。AxiomRL 支持多种配置来源,最终统一解析为 TrainConfig 数据类。
YAML 配置文件¶
最常见的方式是编写 YAML 文件:
algo: PPO
env_id: CartPole-v1
seed: 42
total_timesteps: 100000
output_dir: ./runs/cartpole_ppo
num_envs: 4
eval_episodes: 10
TrainConfig 数据类¶
所有配置最终转化为 TrainConfig 实例。以下是其完整字段:
| 字段 | 类型 | 默认值 | 说明 |
|---|---|---|---|
algo | str | 必填 | 算法名称(如 "PPO"、"SAC") |
env_id | str | 必填 | 环境 ID(如 "CartPole-v1") |
seed | int | 必填 | 随机种子,保证可复现性 |
total_timesteps | int | 必填 | 总训练步数 |
output_dir | str | 必填 | 输出目录路径 |
execution_backend | str | "local_sync" | 执行后端 |
device | str | "auto" | 计算设备("auto"、"cpu"、"cuda") |
num_envs | int | 1 | 并行环境数量 |
eval_episodes | int | 5 | 每次评估的回合数 |
log_interval | int | 1 | 日志记录间隔(训练迭代数) |
checkpoint_interval | int | 1 | 检查点保存间隔(训练迭代数) |
tags | tuple | () | 实验标签,用于组织和筛选 |
benchmark | dict | {} | 基准测试相关配置 |
algo_kwargs | dict | {} | 传递给算法的额外参数 |
env_kwargs | dict | {} | 传递给环境的额外参数 |
预设与 CLI 覆盖¶
配置的优先级从低到高为:
- Zoo 预设 — 预定义的最优超参数组合
- YAML 文件 — 用户自定义配置
- CLI 覆盖 — 命令行参数,优先级最高
算法初始化¶
配置加载完成后,框架根据 algo 字段从算法注册表中查找对应的算法类,并完成初始化。
注册表查找¶
AxiomRL 维护一个全局算法注册表。algo 字段的值(如 "PPO")会被映射到对应的算法类(如 rl_training.core.PPO)。查找顺序为:
- Core 层
- Experimental 层
- Contrib 层
策略与网络创建¶
算法初始化时会自动完成以下步骤:
- 根据环境的观测空间和动作空间推断网络结构
- 创建策略网络(Policy Network)和价值网络(Value Network)
- 将网络部署到
device指定的计算设备上 - 应用
algo_kwargs中的自定义超参数
数据收集¶
Collector(数据采集器)负责与环境交互,收集训练所需的经验数据。
Collector 工作机制¶
Collector 在每次采集步骤中执行以下操作:
- 将当前观测输入策略网络,获取动作
- 在环境中执行动作,获取下一观测、奖励和终止信号
- 将转移元组
(obs, action, reward, next_obs, done)存入 Buffer
向量化环境¶
通过 num_envs 参数可以启用多环境并行采集,显著提升数据吞吐量:
性能建议
对于 On-Policy 算法(如 PPO、A2C),增加 num_envs 可以线性提升采集速度。对于 Off-Policy 算法(如 SAC、TD3),通常 1-4 个环境即可满足需求。
经验缓存¶
Buffer(经验缓存)存储 Collector 采集的数据,供训练更新使用。AxiomRL 根据算法类型自动选择合适的缓存策略。
Off-Policy:ReplayBuffer¶
Off-Policy 算法(DQN、SAC、TD3、CQL、IQL、DiscreteSAC)使用 ReplayBuffer:
- 固定容量的环形缓冲区
- 支持均匀随机采样
- 数据可被多次复用
- 支持优先级经验回放(PER)等变体
On-Policy:Rollout Buffer¶
On-Policy 算法(PPO、A2C、TRPO)使用 Rollout Buffer:
- 存储完整的 rollout 数据
- 每次策略更新后清空
- 数据仅使用一次(或少量几个 epoch)
- 包含 GAE(Generalized Advantage Estimation)计算所需的额外信息
训练更新¶
训练循环的核心是策略和价值网络的梯度更新。
更新流程¶
每次训练迭代包含以下步骤:
- 从 Buffer 中采样一个或多个批次的数据
- 计算损失函数(策略损失、价值损失、熵正则化等)
- 执行反向传播,更新网络参数
- 记录训练指标(损失值、梯度范数等)
自定义超参数¶
通过 algo_kwargs 可以精细控制训练更新行为:
algo_kwargs:
learning_rate: 0.0003 # 学习率
batch_size: 256 # 批量大小
gamma: 0.99 # 折扣因子
tau: 0.005 # 目标网络软更新系数
gradient_steps: 1 # 每次采集后的梯度更新步数
评估¶
评估模块定期测试当前策略的性能,并跟踪最优模型。
评估机制¶
评估由 eval_episodes 字段控制:
每次评估会:
- 使用确定性策略(无探索噪声)运行指定数量的回合
- 记录平均回报、标准差、成功率等指标
- 与历史最优结果对比
最优模型追踪¶
评估器自动追踪历史最优性能。当某次评估的平均回报超过历史最佳时,框架会将当前模型保存为 best.pt。
检查点¶
检查点系统负责定期保存训练状态,支持断点续训和模型部署。
保存策略¶
检查点的保存频率由 checkpoint_interval 控制:
检查点文件¶
保存的检查点包含以下文件:
| 文件 | 说明 |
|---|---|
best.pt | 评估性能最优的模型 |
step_N.pt | 第 N 步的定期检查点 |
每个检查点包含完整的训练状态:网络权重、优化器状态、训练步数、随机数生成器状态等,确保断点续训的完全确定性。
断点续训
AxiomRL 的检查点机制保证了续训的确定性——从检查点恢复后继续训练,将得到与不中断训练完全一致的结果。
日志¶
AxiomRL 使用 TensorBoard 作为主要的日志后端,记录训练过程中的各项指标。
日志配置¶
日志记录频率由 log_interval 控制:
记录的指标¶
TensorBoard 日志包含以下典型指标:
- 训练指标 — 策略损失、价值损失、熵、学习率
- 采集指标 — 每回合奖励、回合长度、采集速度(FPS)
- 评估指标 — 平均评估回报、成功率
- 系统指标 — GPU 显存使用、训练耗时
查看日志¶
output_dir 字段指定了所有训练产物(检查点、日志、评估结果)的根目录。