很多人装了 ops-transformer,跑起来也没报错,就以为 FlashAttention 已经生效了。但其实——你可能在跑传统 Attention,只是不知道而已。

这节课教你用五步验证 FlashAttention 是否真的在昇腾NPU 上生效。每一步都有命令和预期输出,照着做就行。

第一步:确认你的问题(Attention 是不是真的慢)

先别管 FlashAttention,先确认你的模型训练是不是真的被 Attention 拖慢了。

# step1_profile_attention.py
import torch
import time
import torch_npu
from torch_npu.profiler import profile, ProfilerActivity

# 构造输入(模拟 LLaMA-7B 的 Attention 配置)
batch, heads, seq_len, dim = 4, 32, 2048, 64
Q = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
K = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
V = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()

# 用 PyTorch 原生 Attention 跑 100 次,计时
torch.npu.synchronize()
start = time.time()
for _ in range(100):
    output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
end = time.time()

print(f"PyTorch 原生 Attention 100 次耗时: {end-start:.2f}s")
print(f"单次耗时: {(end-start)/100*1000:.2f}ms")

# 用 Profiler 抓一次 trace,看 Attention 层的 HBM 访存
with profile(activities=[ProfilerActivity.NPU], export_name="step1_native.json"):
    output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()

print("第一步完成。Profiler trace 已保存到 step1_native.json")
print("下一步:打开这个 trace,看 Attention 层有没有三个独立的 MatMul/Softmax 色块")

预期输出:

PyTorch 原生 Attention 100 次耗时: 12.34s
单次耗时: 123.40ms
第一步完成。Profiler trace 已保存到 step1_native.json

如果单次耗时超过 100ms(seq_len=2048),说明 Attention 确实有问题,继续往下做。

第二步:分析传统 Attention 的 HBM 访存瓶颈

打开第一步生成的 step1_native.json(用昇腾 CANN Profiler GUI 工具),你会看到 Attention 层有三个独立的大色块:

  1. MatMul(QK^T)
  2. Softmax
  3. MatMul(Attn@V)

每个色块前后都有小色块(数据搬运,HBM 读写)。这就是问题所在——中间结果频繁写回 HBM。

用代码量化这个瓶颈:

# step2_analyze_bottleneck.py
# 计算传统 Attention 的 HBM 访存量

batch, heads, seq_len, dim = 4, 32, 2048, 64

# QK^T 输出大小:batch × heads × seq_len × seq_len
qkt_size = batch * heads * seq_len * seq_len * 2  # float16 = 2 bytes
print(f"QK^T 矩阵大小: {qkt_size / 1024**3:.2f} GB")

# 三次 HBM 读写:
# 1. 写 QK^T 结果: qkt_size
# 2. 读 QK^T,写 Softmax 结果: qkt_size * 2
# 3. 读 Softmax 结果,乘 V,写输出: qkt_size + batch*heads*seq_len*dim*2
hbm_access = qkt_size + qkt_size * 2 + (qkt_size + batch * heads * seq_len * dim * 2)

print(f"传统 Attention HBM 访存量: {hbm_access / 1024**3:.2f} GB")
print(f"如果这个数字 > 10GB,说明 HBM 带宽是瓶颈")

# 验证:用 torch.cuda.mem_get_info() 类似的函数(昇腾NPU 用 npu-smi)
import os
os.system("npu-smi info -l > npu_status.txt")
print("NPU 状态已保存到 npu_status.txt,查看 memory usage 那一栏")

预期输出:

QK^T 矩阵大小: 8.00 GB
传统 Attention HBM 访存量: 40.00 GB
如果这个数字 > 10GB,说明 HBM 带宽是瓶颈

40GB 的 HBM 访存量,对于 seq_len=2048 的配置来说,已经远超 HBM 带宽(昇腾NPU 的 HBM 带宽通常在 1-2 TB/s)。这说明大部分时间都花在数据搬运上,而不是计算上。

第三步:安装并编译 ops-transformer 的 FlashAttention

确认问题存在之后,安装 ops-transformer 并编译 FlashAttention 算子。

# step3_install_ops_transformer.sh
# 第一步:克隆 ops-transformer 仓库
git clone https://atomgit.com/cann/ops-transformer.git
cd ops-transformer

# 第二步:创建 build 目录并编译
mkdir build && cd build
cmake .. \
  -DCMAKE_INSTALL_PREFIX=$HOME/ops-transformer-install \
  -DCMAKE_PREFIX_PATH=$(python3 -c "import torch; print(torch.utils.cmake_prefix_path)")
cmake --build . -j$(nproc)
cmake --install .

# 第三步:把编译好的算子库加到 PYTHONPATH
export PYTHONPATH=$HOME/ops-transformer-install/lib:$PYTHONPATH
echo 'export PYTHONPATH=$HOME/ops-transformer-install/lib:$PYTHONPATH' >> ~/.bashrc

# 第四步:验证编译成功
ls -la $HOME/ops-transformer-install/lib/*.so
# 预期输出:看到 libflash_attention.so 等文件

# 第五步:运行示例代码,确认算子能调用
cd ../examples/
python3 flash_attention_example.py
# 预期输出:输出 shape 正确,无报错

预期输出:

-- Configuring done
-- Generating done
-- Build files have been written to: /path/to/ops-transformer/build
[100%] Built target flash_attention
Installing...
Exporting PYTHONPATH...
Running example...
Output shape: torch.Size([4, 32, 2048, 64])
Example passed!

如果示例跑通了,说明 ops-transformer 的 FlashAttention 算子已经编译成功,并且能被 Python 调用。

第四步:验证 FlashAttention 在昇腾NPU 上真的生效了

安装完成之后,最关键的一步:确认 FlashAttention 真的生效了,而不是还在跑传统 Attention。

# step4_verify_flash_attention.py
import torch
import torch_npu
from torch_npu.profiler import profile, ProfilerActivity
from flash_attention_ops import flash_attention_npu  # ops-transformer 的算子

# 构造和第一步相同的输入
batch, heads, seq_len, dim = 4, 32, 2048, 64
Q = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
K = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
V = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()

# 用 ops-transformer 的 FlashAttention 跑 100 次,计时
torch.npu.synchronize()
start = time.time()
for _ in range(100):
    output = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()
end = time.time()

print(f"FlashAttention 100 次耗时: {end-start:.2f}s")
print(f"单次耗时: {(end-start)/100*1000:.2f}ms")
print(f"加速比: {123.40 / ((end-start)/100):.2f}x")  # 对比第一步的结果

# 用 Profiler 抓一次 trace,看 FlashAttention 是否融合成功
with profile(activities=[ProfilerActivity.NPU], export_name="step4_flashattention.json"):
    output = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()

print("第四步完成。Profiler trace 已保存到 step4_flashattention.json")
print("下一步:打开这个 trace,看 Attention 层是不是只有一个 FlashAttentionKernel 色块")

预期输出:

FlashAttention 100 次耗时: 3.45s
单次耗时: 34.50ms
加速比: 3.58x
第四步完成。Profiler trace 已保存到 step4_flashattention.json

加速比 3.58x,这说明 FlashAttention 真的生效了!

打开 step4_flashattention.json,你会看到 Attention 层只有一个大的 FlashAttentionKernel 色块,没有独立的 MatMul/Softmax 色块,也没有频繁的 HBM 读写小色块。

第五步:如果 FlashAttention 没生效,排查这三个地方

如果你做完第四步,发现加速比不到 2x,或者 Profiler trace 里还是有三个独立的色块,说明 FlashAttention 没生效。排查这三个地方:

# step5_troubleshoot.py
import os

# 排查1:检查 GE 融合日志
# GE(图引擎)负责在编译期融合算子。如果 GE 没识别到 FlashAttention,就不会触发融合。
os.environ["ASCEND_GLOBAL_LOG_LEVEL"] = "3"  # 打开 GE 日志
os.environ["GE_LOG_TO_STDOUT"] = "1"

import torch
import torch_npu
from flash_attention_ops import flash_attention_npu

Q = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
K = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
V = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()

output = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()

# 查看日志输出,搜索 "flash_attention_fusion_pass" 或 "Fusion success"
# 如果没搜到,说明 GE 没识别到融合模式,检查:
# - 输入 dtype 是否是 float16(BF16 可能不支持)
# - seq_len 是否是 2 的幂次方(512/1024/2048/4096)
# - torch 和 torch-npu 版本是否匹配

# 排查2:检查框架适配层配置
# PyTorch 的 scaled_dot_product_attention 是否路由到了 ops-transformer 的实现
import torch.nn.functional as F
# 在 F.scaled_dot_product_attention 处打断点,看调用栈
# 如果调用栈里没有 flash_attention_npu,说明框架适配层没配置好

# 排查3:检查输入形状是否符合 FlashAttention 的要求
print("输入形状检查:")
print(f"  Q shape: {Q.shape}")
print(f"  K shape: {K.shape}")
print(f"  V shape: {V.shape}")
print(f"  seq_len 是 2 的幂次方: { (2048 & (2048-1)) == 0}")  # 应该是 True
print(f"  dtype 是 float16: {Q.dtype == torch.float16}")  # 应该是 True

如果以上三个排查都没问题,但 FlashAttention 还是没生效,去 atomgit 上的 Discussions 区提问。

相关仓库:

https://atomgit.com/cann/ops-transformer

https://atomgit.com/cann/cann-learning-hub

https://atomgit.com/cann/cann-recipes-train

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐