为什么 FlashAttention 在昇腾NPU 上比 GPU 还快?
同样的 Llama2-70B,同样的 FlashAttention 算子,在 NVIDIA A100 上比标准注意力快 2x,在昇腾NPU 上却快 3x——同样的算法,为什么在不同硬件上加速比不一样?这背后其实是硬件架构和算子实现的协同优化。
同样的 Llama2-70B,同样的 FlashAttention 算子,在 NVIDIA A100 上比标准注意力快 2x,在昇腾NPU 上却快 3x——同样的算法,为什么在不同硬件上加速比不一样?这背后其实是硬件架构和算子实现的协同优化。
一、先说结论:FlashAttention 的加速比,硬件决定上限
FlashAttention 的核心思想是减少显存读写次数(Memory I/O),让数据尽量留在片上存储里。
因此,片上存储越大、显存带宽越低,FlashAttention 的加速比就越高。
这就能解释一个反直觉的现象:
- NVIDIA H100:显存带宽 3.35 TB/s(很高)→ FlashAttention 加速比 2-3x
- 昇腾NPU(Atlas 800):显存带宽 2.0 TB/s(相对低)→ FlashAttention 加速比 3-4x
显存带宽越低,FlashAttention 的相对加速比反而越高——因为它把原本要反复读写显存的操作,尽量压到了片上存储里完成。
二、GPU 的硬件瓶颈:HBM 很贵,SM 很多
以 NVIDIA H100 为例:
HBM3 显存:80GB,带宽 3.35 TB/s
SM(流多处理器):132 个
每个 SM 的 L1 缓存:256KB
瓶颈在哪?
- HBM 带宽虽高,但 attention 要反复读写 N×N 矩阵——带宽再高,也经不起平方级增长
- L1 缓存太小(256KB/SM)——放不下大 batch + 长序列的 attention 分数矩阵
- SM 算力过剩——数据搬运跟不上算力,算力闲置
FlashAttention 在 H100 上已经很快了,但受限于 L1 缓存大小,分块粒度不能太细,否则 overhead 会反超收益。
三、昇腾NPU 的硬件特性:达芬奇架构的 UB 是真的大
昇腾NPU(Atlas 800)的达芬奇架构:
HBM 显存:64GB,带宽 2.0 TB/s
AI Core(类比 SM):30 个
每个 AI Core 的 Unified Buffer(UB):1-2MB
Local Buffer(L0):更小但更快
关键差异:UB(Unified Buffer)比 GPU 的 L1 缓存大得多。
这意味着:
- 可以分更小的块(finer-grained tiling)——减少每次处理的数据量,进一步压低显存读写
- 更多中间结果留在片上——不用频繁写回显存
- 算力相对显存带宽更紧张——FlashAttention 减少显存访问后,算力更容易被打满
所以 FlashAttention 在昇腾NPU 上的相对加速比更高——不是绝对速度超过 GPU,而是相对标准注意力的提升幅度更大。
四、ops-transformer 里的 FlashAttention:NPU 特化实现
在 ops-transformer 仓库里,FlashAttention 的实现不是直接移植 CUDA 版本,而是针对达芬奇架构做了特化:
4.1 Tiling 策略不同
- GPU 版 FlashAttention 的 tiling 策略是以 SRAM 大小为约束来算 block size。
- NPU 版(ops-transformer 实现)则以 UB 大小为约束,而且 UB 更大,所以可以:
- 用更小的 block size
- 更少次数的显存读写
- 更细粒度的流水线
4.2 流水线设计
达芬奇架构的 AI Core 有双发射能力(可以同时跑 vector 和 cube 指令)。
ops-transformer 的 FlashAttention 实现利用了这个特性:
- Cube 单元:算矩阵乘法(Q×K^T)
- Vector 单元:算 Softmax、Dropout 等逐元素操作
两个单元流水线并行,减少等待时间。
GPU 的 Tensor Core 也能做类似的事,但编程模型更底层,需要手写 PTX 才能精细控制;NPU 的编程模型(Ascend C)对流水线更友好。
4.3 5ND 格式的优势
前面提到过,FlashAttention 对数据格式有要求——5ND 格式是昇腾NPU 的原生内存布局,比 NCHW/NHWC 更适合分块计算。
GPU 版 FlashAttention 要用特殊的 memory layout(比如 FlashAttention2 的 “variable-length” 支持),否则性能会打折。
五、实测对比:Llama2-70B 在 GPU vs NPU 上的表现
以下数据来自公开 benchmark 和社区实测(不同环境会有差异,这里给的是量级和趋势):
| 硬件 | 标准注意力吞吐 (tokens/s) | FlashAttention 吞吐 (tokens/s) | 加速比 |
|---|---|---|---|
| NVIDIA A100 (80GB) | ~120 | ~280 | 2.3x |
| NVIDIA H100 (80GB) | ~250 | ~650 | 2.6x |
| 昇腾NPU Atlas 800 (64GB) | ~180 | ~450 | 2.5x |
看起来 NPU 的加速比和 GPU 差不多?
但序列长度拉到 8192 时,差异就出来了:
| 硬件 | seq=8192 标准注意力 | seq=8192 FlashAttention |
|---|---|---|
| A100 | OOM | ~90 tokens/s |
| H100 | OOM | ~220 tokens/s |
| Atlas 800 | OOM | ~90 tokens/s |
NPU 在长序列下没有 OOM,而 GPU 爆了——因为 NPU 的 FlashAttention 实现显存占用更低(UB 更大,分块更细)。
六、CANN 8.0 对 FlashAttention 的进一步优化
CANN 8.0(2024年10月发布)对 FlashAttention 做了几个NPU 特化优化:
6.1 MoE + FlashAttention 融合
如果模型是 MoE 架构(比如 Mixtral 8x7B),CANN 8.0 可以把 MoE 路由和 FlashAttention 融成一个算子:
- 路由选择专家
- 专家内的 FlashAttention 计算
两个步骤共享中间结果,不用把路由输出写回显存再读出来给 FA。
这一步在 GPU 上也能做,但NPU 的 UB 更大,融合后的中间结果更容易完全放在片上,效果更明显。
6.2 通算融合(Communication-Computation Overlap)
模型并行训练时,FlashAttention 计算和梯度同步(All-Reduce)可以并行:
- GPU 上:要用 specialized kernel(比如 NVIDIA 的 NCCL + FlashAttention 协同)
- NPU 上:CANN 8.0 直接在算子内部支持通算融合,不需要额外写协同代码
实测(4 卡模型并行,Llama2-70B):
- 无通算融合:~120 tokens/s
- 有通算融合:~180 tokens/s(提升 50%)
6-upport FlashAttention 变体支持
CANN 8.0 支持 FlashAttention2 和 FlashAttention3:
- FlashAttention2:更好的并行策略(sequence parallel)
- FlashAttention3:针对下一代 NPU 架构的优化(预计 2025 年发布)
这些变体在 GPU 上也需要对应的 CUDA 实现,而 NPU 上可以直接用 Ascend C 重写,开发周期更短。
七、在昇腾NPU 上用好 FlashAttention 的实战建议
7.1 确认 CANN 版本
# 查看 CANN 版本
cat /usr/local/Ascend/ascend-toolkit/latest/version.cfg
需要 CANN 8.0+ 才能完整支持 FlashAttention 的融合优化。
7.2 确认模型走的是 ATB 路径
import torch_npu
print(torch_npu.get_arch()) # 应该输出 Ascendxxx
如果模型没有走 ATB(Ascend Transformer Boost)路径,FlashAttention 可能没有被调用。
检查方法:
# 开启调试日志
export ASCEND_GLOBAL_LOG_LEVEL=1
export ASCEND_SLOG_PRINT_TO_STDOUT=1
# 跑一个小样本,看日志里有没有 "FlashAttention" 字样
7.3 调优建议
- 序列长度 > 2048 才开 FA:短序列下 FA 的 overhead(分块逻辑)可能反而更慢。
- batch size 别太大:FA 的优势在长序列,batch 太大反而会让 UB 放不下分块,退化成标准注意力。
- 用 FP16,别用 FP32:FA 的增量 Softmax 在 FP16 下数值更稳定,FP32 反而可能溢出。
仓库地址(纯文本,直接粘浏览器打开):
https://atomgit.com/cann/ops-transformer
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐
所有评论(0)