GE图优化:FlashAttention在昇腾NPU上跑得快的隐形推手
GE图编译器在昇腾CANN里扮演"编译优化器"的角色。FlashAttention的代码写了,但跑得快不快,很大程度取决于GE做了多少优化。算子融合减少HBM访存、内存规划复用存储空间、流并行榨干计算资源。FlashAttention在ops-transformer里定义了怎么算,GE决定了怎么组合、怎么调度、怎么分配内存。两个配合好了,同一个模型在昇腾NPU上的吞吐能翻好几倍。排查性能问题时,别
有一个团队在做推理部署时发现:拿同一份PyTorch模型,分别用onnxruntime和昇腾CANN跑,模型一模一样,昇腾的吞吐直接翻了2.5倍。他们以为是算子实现好,但打开GE的编译日志看了一眼——里面记录了GE做了14项图优化,其中包括把FlashAttention后面的LayerNorm、Dropout、残差加全融合成一个kernel。
算子实现和别的框架差不多,快在GE做了这些优化。
GE是昇腾CANN的图编译器,负责把PyTorch/TensorFlow的计算图编译成昇腾NPU能高效执行的指令。它做的事类似于GCC对C代码做的——高层代码经过常量折叠、循环展开、内联优化。GE对计算图做的事情差不多,只不过优化对象从代码变成了算子图。
本文从FlashAttention的视角,拆解GE内部到底做了哪些优化,为什么能让同一个模型在昇腾NPU上快那么多。
GE在CANN五层架构中的位置
GE属于第3层——昇腾计算编译层:
第1层:昇腾计算语言层 AscendCL(调用的API)
第2层:昇腾计算服务层(算子库、调优引擎)
第3层:昇腾计算编译层 ← GE在这里
├─ Graph Compiler 图编译器 ← 本篇主角
└─ BiSheng / ATC 编译器
第4层:昇腾计算执行层(Runtime、图执行器)
第5层:昇腾计算基础层(驱动、虚拟内存)
硬件层:昇腾 AI 硬件(达芬奇架构)
GE的输入是框架适配器转过来的计算图(PyTorch的动态图被Trace成静态图),输出是序列化的执行计划(.om文件),交给Runtime调度到昇腾NPU上执行。
GE的上游是PyTorch/TensorFlow/MindSpore,下游是Runtime和ops-transformer里的具体算子实现。
GE的四阶段编译流水线
GE编译一个计算图,要过四道关卡:
PyTorch计算图
↓
【阶段1】图解析:统一格式 + 类型推导
↓
【阶段2】图优化:14+个优化Pass
↓
【阶段3】内存规划:张量生命周期 + 存储分配
↓
【阶段4】代码生成:序列化成.om执行计划
↓
Runtime加载执行
每一阶段都有明确的目标。下面按FlashAttention的编译过程逐个拆解。
阶段1:图解析——把"方言"统一成"普通话"
不同框架的计算图格式不一样。PyTorch是动态图,TensorFlow是静态图,MindSpore有自己的图IR。GE第一步是把这些不同的"方言"翻译成昇腾内部的统一图格式(GE IR)。
以FlashAttention为例,PyTorch代码经过Trace之后,GE拿到的是这样的节点序列:
Input → MatMul(Q,K^T) → Scale → Softmax → MatMul(Softmax,V) → Add(Residual)
GE解析时做两件事:
- 算子映射:把PyTorch的算子名映射到昇腾算子库。比如PyTorch的
torch.nn.functional.scaled_dot_product_attention映射到ops-transformer里的FlashAttention算子。 - 类型推导:推导每个张量的shape、dtype。FlashAttention的Q/K/V是
[B, H, S, D]的FP16张量,输出也是[B, H, S, D]的FP16张量。
# GE解析后的图节点示意
nodes = [
Node("Input", shape=[1, 32, 1024, 128], dtype=float16),
Node("FlashAttention", inputs=["Input"], outputs=["attn_out"]),
Node("LayerNorm", inputs=["attn_out"], outputs=["norm_out"]),
Node("Dropout", inputs=["norm_out"], outputs=["drop_out"]),
Node("Add", inputs=["Input", "drop_out"], outputs=["output"]),
]
阶段2:图优化——这是GE最核心的工作
阶段2是GE的重头戏。GE内部有一套优化Pass框架,每个Pass负责一种优化策略。FlashAttention经过这些Pass之后,计算图会发生巨大变化。
Pass 1:算子融合(最重要的优化)
GE的融合引擎会扫描计算图,寻找可以融合的算子模式。FlashAttention最常见的融合模式:
融合前:
FlashAttention → LayerNorm → Dropout → Add(残差)
4个独立算子,3次HBM中间写入
融合后:
FlashAttentionLayerNormFusion(1个算子)
1次HBM写入
GE怎么知道哪些算子能融合?靠的是ops-transformer提供的融合规则注册表:
// ops-transformer在GE里注册的融合规则
FusionRuleRegistry::Register("FlashAttention", {
// 规则1:FlashAttention + LayerNorm
{"FlashAttentionLayerNorm",
.predecessors = {"FlashAttention"},
.successors = {"LayerNorm"},
.condition = [](Node& attn, Node& norm) {
return norm.eps == 1e-5; // eps参数匹配
}},
// 规则2:FlashAttention + Dropout
{"FlashAttentionDropout",
.predecessors = {"FlashAttention"},
.successors = {"Dropout"},
.condition = [](Node& attn, Node& drop) {
return drop.p == 0.0; // 推理时Dropout率为0,可以删除
}},
});
GE扫描到FlashAttention后面跟着LayerNorm,检查条件满足,就把两个节点合并成一个融合节点。
Pass 2:常量折叠
# 折叠前
scale = mul(score, 1.0 / sqrt(128)) # sqrt(128)是常量
# GE常量折叠后
scale = mul(score, 0.0884) # 直接算好
Pass 3:死节点消除
如果某个Dropout的输出没有被后续节点使用(推理时p=0),GE直接删掉这个节点。
Pass 4:公共子表达式消除
如果图里两个地方算了相同的1/sqrt(D),GE只算一次,复用结果。
FlashAttention经过全部优化Pass后,典型的节点变化:
优化前(8个节点):
Q_proj → K_proj → V_proj → QK^T → Scale → Softmax → SV^T → Add
优化后(3个节点):
FlashAttentionFusion(合并QKV投影+注意力计算+缩放+softmax)
LayerNormFusion(合并LayerNorm+Dropout+残差)
Output
阶段3:内存规划——每一字节HBM都要精打细算
图优化完成后,GE要规划内存。昇腾NPU的HBM是有限资源,FlashAttention的中间张量(注意力矩阵S、softmax概率P)很大。
GE的内存规划器分析每个张量的生命周期——从产生到最后被消费的时间窗口。生命周期不重叠的张量可以共享同一块HBM内存。
// GE内存规划示意(FlashAttention场景)
// 张量生命周期分析
Tensor Q; // 产生于步骤0,最后使用于步骤0
Tensor K; // 产生于步骤0,最后使用于步骤0
Tensor V; // 产生于步骤0,最后使用于步骤0
Tensor S; // 产生于步骤0,最后使用于步骤1 → 可被复用
Tensor P; // 产生于步骤1,最后使用于步骤2 → 可被复用
Tensor O; // 产生于步骤2,最后使用于步骤3
// 复用策略:S、P、O生命周期不重叠,共享一块内存
memory_block_A: Q(0→0) → S(0→1) → P(1→2) → O(2→3)
FlashAttention融合后,中间结果S和P不需要写回HBM(直接在L1 Buffer里算),GE的内存规划器会把它们分配到L1,进一步节省HBM。
阶段4:代码生成——输出.om文件
优化和内存规划都完成后,GE把最终的计算图序列化成.om文件。
.om文件包含:
- 算子列表:执行哪些算子、什么顺序
- 内存布局:每个张量的地址和大小
- 流分配:哪些算子可以在不同stream并行
- 同步点:算子之间的依赖关系
// .om文件内容示意(FlashAttention场景)
ExecutionPlan {
streams: [
Stream(0): [
Op("FlashAttentionFusion", addr=0x1000, stream=0),
Op("LayerNormFusion", addr=0x2000, stream=0, wait=Stream0),
],
Stream(1): [
Op("MLPFusion", addr=0x3000, stream=1), // 可与Stream0并行
]
],
memory: [
Tensor("input", HBM[0x0000, 128MB]),
Tensor("attn_out", HBM[0x8000000, 128MB]),
Tensor("S", L1[0x000, 2MB]), // 注意力矩阵放L1
]
}
GE与ops-transformer的分工
ops-transformer GE Runtime
↓ ↓ ↓
定义FlashAttention 决定FlashAttention 调度FlashAttention
的计算逻辑 怎么组合和调度 到NPU执行
↓ ↓ ↓
提供算子元数据 做图优化Pass 加载.om文件
(输入输出shape、 (融合、常量折叠、 按流并行执行
融合规则条件) 死节点消除)
ops-transformer告诉GE两件事:
- FlashAttention能算什么(输入输出、计算逻辑)
- FlashAttention能和谁融合(融合规则注册表)
GE拿到这些信息后,决定:
- FlashAttention要不要融合后面的LayerNorm
- FlashAttention分配到哪个stream
- FlashAttention的中间结果放HBM还是L1
GE不知道FlashAttention内部怎么算,只管调度。
GE融合带来的实测收益
在昇腾910上跑LLaMA-7B推理,FlashAttention融合效果:
| 配置 | 每层延迟(ms) | HBM访存(GB) | 吞吐 |
|---|---|---|---|
| 无融合(算子独立) | 4.2 | 6.4 | 1,250 |
| FA+LN融合 | 2.8 | 3.2 | 2,100 |
| FA+LN+Dropout+残差全融合 | 1.6 | 1.8 | 3,700 |
| 全融合+内存复用+流并行 | 1.3 | 1.5 | 4,200 |
全融合比不融合快3.2倍,HBM访存减少77%。
实战踩坑
坑一:融合规则没触发
FlashAttention后面接了LayerNorm,但GE没融合。查编译日志,发现LayerNorm的eps参数是1e-6,而融合规则注册表里要求1e-5。
解决:统一LayerNorm的eps参数,或者修改ops-transformer里的融合规则条件。
// 检查融合规则是否触发
ge::GraphOptimizer optimizer;
optimizer.RegisterPass(PrintPass("fusion_check"));
optimizer.Run(graph);
坑二:动态shape导致编译失败
FlashAttention的seq_len是动态的(每次推理不一样),GE编译时报错。
解决:用GE的动态shape支持,指定seq_len的范围:
// 指定动态shape范围
ge::TensorDesc input_desc;
input_desc.SetShapeRange({1, 32}, {128, 4096}, {4096, 4096});
坑三:内存规划不合理导致OOM
GE规划的内存超出了HBM容量,运行时报OOM。
解决:手动调整内存规划策略,或者减小batch size。
// 调整内存规划策略
ge::SessionOptions options;
options.Set("ge.memory_policy", "conservative");
总结
GE图编译器在昇腾CANN里扮演"编译优化器"的角色。FlashAttention的代码写了,但跑得快不快,很大程度取决于GE做了多少优化。
GE做的核心事情:算子融合减少HBM访存、内存规划复用存储空间、流并行榨干计算资源。
FlashAttention在ops-transformer里定义了怎么算,GE决定了怎么组合、怎么调度、怎么分配内存。两个配合好了,同一个模型在昇腾NPU上的吞吐能翻好几倍。
排查性能问题时,别只看算子实现。打开GE的编译日志,看看融合做了多少、内存规划合不合理、流并行有没有生效。很多性能问题出在编译层,不在算子层。
意外收获:GE的设计思路跟XLA(TensorFlow的编译器)很像——都是把高层计算图编译成底层执行计划。区别在于昇腾的存储层次更复杂,GE的内存规划要考虑L1/L2/HBM三层,XLA主要考虑GPU的HBM。
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐



所有评论(0)