作者:昇腾实战派
知识地图https://blog.csdn.net/Lumos_Lovegood/article/details/161455142

一、简介

Fully Async Policy 是VeRL大模型强化学习开源仓库的一项异步训练特性,通过将训练和推理完全解耦为Trainer和Rollouter,达成训推并行异步的训练架构。

昇腾设备已支持VeRL大模型强化学习框架

如下图所示,fully async支持四种不同程度的异步训练:

在这里插入图片描述

a. on policy pipeline - 只是异卡训练,作为后面几种模式的 baseline,实际意义不大

b. stream off policy pipeline - 将一次参数同步间隔拆分为多个minibatch,推理推完一个minibatch就进行一次参数更新

c. async stream pipeline with staleness samples - 加入staleness(新鲜度)概念,顾名思义就是允许 Rollouter一直生成样本,模型可以使用上一次参数同步之前Rollouter生成的样本,用staleness来控制可以使用样本的比例

d. async stream pipeline with partial rollout - 加入partial rollout,允许本轮未推完的样本在参数同步后接着推(也就是说允许一条样本由两个不同参数的模型推理)

优势在于:

  1. 资源隔离与优化:Rollouter 和 Trainer 可以使用独立的计算资源,从而可以根据各自的负载进行精细化的资源分配,避免了同步训练中因一方等待另一方而造成的资源闲置。
  2. 并行化与效率提升:Trainer 在进行模型训练的同时,Rollouter 可以不间断地生成新的训练样本,两者并行工作,极大地重叠了时间消耗,从而缩短了端到端的训练时间。
  3. Stalness & partial rollout:两个特性进一步填充了训练中的空闲时间

二、代码总览

主要组件:

在这里插入图片描述

  1. Rollouter (fully_async_rollouter.py): 负责与环境交互,生成训练数据,并将其放入消息队列。
  2. Trainer (fully_async_trainer.py): 从消息队列中获取数据,执行 PPO 算法的训练步骤。
  3. MessageQueue (message_queue.py): 作为 Rollouter 和 Trainer 之间的缓冲池,解耦两者的生产和消费速度。
  4. CheckpointEngineManager: 基于 NCCL/HCCL 的参数同步器,负责在 Trainer 和 Rollouter 之间同步模型权重,

然后来总览一下代码文件:

fully_async_policy/
├── README.md                          # 英文说明文档
├── README_zh.md                       # 中文说明文档
├── fully_async_main.py                # 主入口文件
├── fully_async_rollouter.py           # Rollouter 实现
├── fully_async_trainer.py             # Trainer 实现
├── message_queue.py                   # 消息队列实现
├── detach_utils.py                    # 工具函数和数据结构
├── agent_loop/
│   ├── __init__.py
│   └── agent_loop.py                  # 支持 partial rollout 的 AgentLoop
├── config/
│   ├── fully_async_ppo_megatron_trainer.yaml
│   └── fully_async_ppo_trainer.yaml  # 配置文件
├── shell/                              # 启动脚本
│   ├── dapo_30b_a3b_base_math_fsdp.sh
│   ├── dapo_7b_async_retool.sh
│   ├── ... (其他启动脚本)
│   └── runtime_env.yaml
└── unittest/
    └── simple_streaming_demo.py

可以通过下图简单对应关键类和整体架构:

在这里插入图片描述

VeRL中通过不同的TaskRunner来走向不同的训练方式,TaskRunner中完成了各组件的初始化和最上层的训练循环执行;Rollouter和Trainer中分别定义了训练和推理各自的流程,通过fit()方法来执行,数据通过MessageQueueClient来传递,参数同步则通过CheckpointEngineManager完成。

接下来详细介绍一下各组件的代码构成

三、组件构成

1. Rollouter

在这里插入图片描述

Rollouter的fit()主流程包含一个主函数_stream_generate_main()和监视器_async_monitor_loop();

• _stream_generate_main()中获取prompt的方法_feed_samples()和执行generate的_processor_worker()异步执行:

○ _feed_samples()负责持续获取prompt batch放入pending queue中,直到所有batch都被消耗完;

○ _processor_worker()则会在一开始用_should_pause_generation()判断生成的样本是否已满足请求的数量或超过staleness阈值的旧样本,如果通过则持续调用_process_single_sample_streaming()来进行当前batch的推理,不通过则pause等待monitor唤醒;如果开启了partial rollout,还会对上一个batch推理中断队列的样本进行续推,最终将推理完成的样本放入message queue中

• _async_monitor_loop()则以10秒的间隔轮询_should_pause_generation()来唤醒_processor_worker();以60秒的间隔打印统计信息

2. Trainer

在这里插入图片描述

Trainer的逻辑比较直观:

  1. 首先执行_get_samples_from_queue 从MessageQueue中获取样本,当样本请求到None时中止循环;
  2. 然后由_fit_step中依次执行_fit_compute_reward()、_fit_compute_log_prob()…,进行训练流程
  3. 当trigger_parameter_sync_step达到目标步数后,进行与Rollouter唯一一次交互 - 参数同步

3. MessageQueue

在这里插入图片描述

message_queue.py 实现了一个的异步消息队列系统,用于 Rollouter 和 Trainer 之间的训练样本传递。主要包含两个类:

MessageQueue:基于 deque 实现的异步消息队列,使用 asyncio.Lock 和 Condition 实现线程安全,当队列满时自动丢弃最旧的样本

MessageQueueClient:提供异步和同步两种调用方式,封装 Ray remote 调用

4. CheckpointEngine

VeRL中Checkpoint engine NPU的实现在hccl_checkponit_engine.py文件中,包含了init_process_group()、recive_weights()、send_weights()等方法,主要用于Rollouter和Trainer间的参数传递。

四、完整时间线

在这里插入图片描述

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐