作者:昇腾实战派

1. 背景概述

在大语言模型的强化学习训练过程中,PPO(Proximal Policy Optimization)算法因其稳定性和效率而被广泛应用。在昇腾Atlas 800T A2上基于Align-Anything框架进行Qwen2.5-0.5B模型PPO训练时遇到性能瓶颈问题,经过分析,整体性能影响主要受到Host侧下发约束,本文将详细介绍完整的性能分析与优化过程。

align-anything开源仓:https://github.com/PKU-Alignment/align-anything/tree/main

2. 性能分析与优化过程

2.1 性能数据采集与分析

针对pytorch场景,可通过以下方式采集性能数据:https://www.hiascend.com/document/detail/zh/mindstudio/81RC1/msquickstart/atlasquick_train_0018.html

设置with_stack=True可获取代码堆栈信息,便于定位问题算子。为避免对性能产生过大影响,初步分析时不开启堆栈采集。

获取到prof后,使用性能工具进行对比分析,工具使用方式参考:https://gitcode.com/Ascend/mstt/tree/master/profiler/msprof_analyze/#%E6%80%A7%E8%83%BD%E5%AF%B9%E6%AF%94%EF%BC%88compare%EF%BC%89%E5%AD%90%E5%8A%9F%E8%83%BD

采集完成的prof可以通过MindStudio-Insight工具进行可视化,便于分析。工具下载方式:https://www.hiascend.com/developer/download/community/result?module=pt+sto+cann

在这里插入图片描述

使用MindStudio-Insight工具可视化timeline后发现,NPU操作之间存在明显空隙,pytorch→CANN→Hardware连线垂直,表明存在Host Bound问题。

2.2 常规优化方案

首先尝试常规优化手段:

  • 二级流水优化:将部分算子适配任务迁移至二级流水,使两级流水负载更均衡,减少dequeue唤醒时间
  • 绑核优化、禁用私有格式、关闭JIT编译,使用毕昇编译优化进行优化等

二级流水优化能够带来明显效果,整体优化后端到端耗时减少15%。

2.3 深入分析流水线问题

以上常用方案使能后优化效果不明显,free time过长,怀疑是否是cpu出现核间抢占、流水中断等问题,需要采集执行过程中的ftrace文件。
在这里插入图片描述
在这里插入图片描述

可以看出acl_thread时间很短,表明下发任务耗时正常,无核间抢占或流水中断问题,优化重点应减少下发次数。

2.4 代码逻辑优化

2.4.1 融合算子替换

在rollout阶段使用原生transformers推理效率较低,循环生成算子下发频繁。

在这里插入图片描述

通过使能融合算子进行优化:
transformers/models/qwen2/modeling_qwen2.py中:

  • 替换npu_rotary_mul融合算子
  • 替换npu_rms_norm融合算子

在环境中transformers/models/qwen2/modeling_qwen2.py 修改,transformers版本为4.53.2

npu_rotary_mul
在这里插入图片描述

npu_rms_norm
在这里插入图片描述

使能融合算子后端到端可得到一定提升。

2.4.2 GAE计算优化

使能融合算子后,GAE部分耗时久,free time 也较多,利用堆栈定位到定位到compute_bi_level_gae_advantage_return函数存在性能瓶颈。
在这里插入图片描述

优化思路

将原版「按样本逐 batch for 循环 + Python 原生判断 + 标量级索引取值」的串行低效逻辑,重构为「PyTorch 全张量化并行计算 + 规整 padding 对齐 + 批量掩码筛选 + 向量化索引赋值」的并行高效逻辑​。

核心改进

  • 用张量的矩阵式批量运算替代Python的逐元素串行循环
  • 规避原生for循环的低效开销与CPU侧的Host Bound问题
  • 在不改变双层 GAE 的数学计算逻辑、保证结果完全一致的前提下,提升计算效率、减少 device 等待

最终测试端到端会有较大提升。

2.4.3 超参数调整

修改配置文件中的es_manager里的超参数,该参数会控制rollout的batch size以及训练算梯度的batch size,增大batch size以增加device侧计算量,充分利用算力
在这里插入图片描述

由于调整超参,总体耗时会增加,但free time占比下降到40%

2.4.4 显存优化

调整超参后,采集性能数据与内存数据,发现在反向会出现尖刺导致显存打满从而触发内存重整,影响性能,通过堆栈定位此处显存尖刺的位置。
在这里插入图片描述
在这里插入图片描述

定位到代码位置后,发现未释放中间张量,中间结果会持续占用显存直至该作用域结束,推高显存峰值。
优化方式:通过contiguous()方法重整切片后张量的内存布局,消除非连续张量的冗余内存占用;
定义独立中间变量并在计算完成后通过del显式释放,及时回收显存资源;

在这里插入图片描述

总结

本次优化实践表明,通过系统的性能分析、常规优化与代码级深度优化相结合,能够有效解决昇腾平台在复杂训练场景下的性能瓶颈问题。

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐