Fully Async 代码走读
作者:昇腾实战派
知识地图:https://blog.csdn.net/Lumos_Lovegood/article/details/161455142
一、简介
Fully Async Policy 是VeRL大模型强化学习开源仓库的一项异步训练特性,通过将训练和推理完全解耦为Trainer和Rollouter,达成训推并行异步的训练架构。
昇腾设备已支持VeRL大模型强化学习框架
如下图所示,fully async支持四种不同程度的异步训练:

a. on policy pipeline - 只是异卡训练,作为后面几种模式的 baseline,实际意义不大
b. stream off policy pipeline - 将一次参数同步间隔拆分为多个minibatch,推理推完一个minibatch就进行一次参数更新
c. async stream pipeline with staleness samples - 加入staleness(新鲜度)概念,顾名思义就是允许 Rollouter一直生成样本,模型可以使用上一次参数同步之前Rollouter生成的样本,用staleness来控制可以使用样本的比例
d. async stream pipeline with partial rollout - 加入partial rollout,允许本轮未推完的样本在参数同步后接着推(也就是说允许一条样本由两个不同参数的模型推理)
优势在于:
- 资源隔离与优化:Rollouter 和 Trainer 可以使用独立的计算资源,从而可以根据各自的负载进行精细化的资源分配,避免了同步训练中因一方等待另一方而造成的资源闲置。
- 并行化与效率提升:Trainer 在进行模型训练的同时,Rollouter 可以不间断地生成新的训练样本,两者并行工作,极大地重叠了时间消耗,从而缩短了端到端的训练时间。
- Stalness & partial rollout:两个特性进一步填充了训练中的空闲时间
二、代码总览
主要组件:

- Rollouter (fully_async_rollouter.py): 负责与环境交互,生成训练数据,并将其放入消息队列。
- Trainer (fully_async_trainer.py): 从消息队列中获取数据,执行 PPO 算法的训练步骤。
- MessageQueue (message_queue.py): 作为 Rollouter 和 Trainer 之间的缓冲池,解耦两者的生产和消费速度。
- CheckpointEngineManager: 基于 NCCL/HCCL 的参数同步器,负责在 Trainer 和 Rollouter 之间同步模型权重,
然后来总览一下代码文件:
fully_async_policy/
├── README.md # 英文说明文档
├── README_zh.md # 中文说明文档
├── fully_async_main.py # 主入口文件
├── fully_async_rollouter.py # Rollouter 实现
├── fully_async_trainer.py # Trainer 实现
├── message_queue.py # 消息队列实现
├── detach_utils.py # 工具函数和数据结构
├── agent_loop/
│ ├── __init__.py
│ └── agent_loop.py # 支持 partial rollout 的 AgentLoop
├── config/
│ ├── fully_async_ppo_megatron_trainer.yaml
│ └── fully_async_ppo_trainer.yaml # 配置文件
├── shell/ # 启动脚本
│ ├── dapo_30b_a3b_base_math_fsdp.sh
│ ├── dapo_7b_async_retool.sh
│ ├── ... (其他启动脚本)
│ └── runtime_env.yaml
└── unittest/
└── simple_streaming_demo.py
可以通过下图简单对应关键类和整体架构:

VeRL中通过不同的TaskRunner来走向不同的训练方式,TaskRunner中完成了各组件的初始化和最上层的训练循环执行;Rollouter和Trainer中分别定义了训练和推理各自的流程,通过fit()方法来执行,数据通过MessageQueueClient来传递,参数同步则通过CheckpointEngineManager完成。
接下来详细介绍一下各组件的代码构成
三、组件构成
1. Rollouter

Rollouter的fit()主流程包含一个主函数_stream_generate_main()和监视器_async_monitor_loop();
• _stream_generate_main()中获取prompt的方法_feed_samples()和执行generate的_processor_worker()异步执行:
○ _feed_samples()负责持续获取prompt batch放入pending queue中,直到所有batch都被消耗完;
○ _processor_worker()则会在一开始用_should_pause_generation()判断生成的样本是否已满足请求的数量或超过staleness阈值的旧样本,如果通过则持续调用_process_single_sample_streaming()来进行当前batch的推理,不通过则pause等待monitor唤醒;如果开启了partial rollout,还会对上一个batch推理中断队列的样本进行续推,最终将推理完成的样本放入message queue中
• _async_monitor_loop()则以10秒的间隔轮询_should_pause_generation()来唤醒_processor_worker();以60秒的间隔打印统计信息
2. Trainer

Trainer的逻辑比较直观:
- 首先执行_get_samples_from_queue 从MessageQueue中获取样本,当样本请求到None时中止循环;
- 然后由_fit_step中依次执行_fit_compute_reward()、_fit_compute_log_prob()…,进行训练流程
- 当trigger_parameter_sync_step达到目标步数后,进行与Rollouter唯一一次交互 - 参数同步
3. MessageQueue

message_queue.py 实现了一个的异步消息队列系统,用于 Rollouter 和 Trainer 之间的训练样本传递。主要包含两个类:
MessageQueue:基于 deque 实现的异步消息队列,使用 asyncio.Lock 和 Condition 实现线程安全,当队列满时自动丢弃最旧的样本
MessageQueueClient:提供异步和同步两种调用方式,封装 Ray remote 调用
4. CheckpointEngine
VeRL中Checkpoint engine NPU的实现在hccl_checkponit_engine.py文件中,包含了init_process_group()、recive_weights()、send_weights()等方法,主要用于Rollouter和Trainer间的参数传递。
四、完整时间线

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐



所有评论(0)