大模型推理卡在哪?FlashAttention算子在昇腾NPU上的实现拆解
《大模型推理瓶颈与FlashAttention在昇腾NPU上的优化实现》 摘要:大模型推理的核心瓶颈在于Self-Attention计算过程中产生的巨大中间矩阵(如seq_len=8192时可达32GB显存占用)。FlashAttention通过将计算流程保留在片上SRAM避免HBM频繁读写,其昇腾NPU实现具有三大特性:1)利用1.5MB的UnifiedBuffer实现更大分块;2)CubeUn
为什么 Attention 是瓶颈?
先回顾一下问题本身。标准 Self-Attention 的计算过程:
Q, K, V = Linear(x) # 投影
S = Q @ K^T # 注意力分数
P = Softmax(S) # 归一化
O = P @ V # 加权求和
看起来就四步,但问题出在显存访问上。Q、K、V 的 shape 是 [batch, heads, seq_len, dim],当 seq_len 到 8192 甚至更长的时候,中间矩阵 S 的 shape 是 [batch, heads, seq_len, seq_len],这个矩阵大得离谱。以 LLaMA 13B 为例,32 个注意力头,seq_len=8192,S 矩阵光是 FP16 就要占 32GB 显存,根本放不下。
而且这个 S 矩阵算完 Softmax 之后还要跟 V 做矩阵乘法,意味着要再读一遍。来回读写 HBM(显存)的带宽就成了瓶颈。
FlashAttention 的核心思路:不分步计算,把 Attention 整个流程放在片上 SRAM 里完成,避免中间结果写回 HBM。
听起来简单,做起来要处理两个问题:Softmax 的在线计算(因为不知道全局最大值没法直接算 Softmax)和分块策略(SRAM 容量有限,得分块处理)。
标准实现 vs IO-Aware 实现
先看标准实现的问题在哪。
标准实现(Naive Attention):
import torch
import torch.nn.functional as F
def naive_attention(query, key, value):
"""标准 Self-Attention,中间结果全部落回 HBM"""
d_k = query.size(-1)
# Q @ K^T 产生 [batch, heads, seq_len, seq_len] 的巨大矩阵
scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
# Softmax 结果也要写回 HBM
p_attn = F.softmax(scores, dim=-1)
# 再读一遍 p_attn,跟 V 做矩阵乘
return torch.matmul(p_attn, value)
4 次 HBM 读 + 4 次 HBM 写,中间矩阵 S 和 P 都要落回显存。seq_len 一大,HBM 带宽直接被撑爆。
FlashAttention 实现(基于 ops-transformer 的调用方式):
import torch
import torch_npu
from ops_transformer import flash_attention
def flash_attention_inference(query, key, value, seq_len, head_dim):
"""调用 ops-transformer 的 FlashAttention 算子
分块计算,中间 Softmax 结果留在 UB(片上 SRAM),不写回 HBM"""
# query/key/value: [batch, num_heads, seq_len, head_dim]
attn_output = flash_attention.flash_attention_score(
query,
key,
value,
drop_mask=None,
padding_mask=None,
attn_head_num=query.shape[1],
attn_dim_per_head=head_dim,
scale_value=1.0 / (head_dim ** 0.5),
input_layout="BSND", # batch-seq-head-dim 排布
seed=0,
pre_tokens=seq_len,
next_tokens=0,
keep_prob=1.0, # 推理不 dropout
)
return attn_output
HBM 读写次数大幅减少。代价是计算量略增(Softmax 的在线修正需要额外计算),但在现代硬件上计算远比显存访问快,所以总体是赚的。
昇腾 NPU 上的关键差异
到这一步,算法思路是一样的,NVIDIA 和昇腾都这么干。但落到具体实现上,昇腾 NPU 有几个关键差异:
差异一:SRAM 结构不同
NVIDIA GPU 的 SRAM 是 shared memory,一个 thread block 内的线程共享,大小通常 48KB-164KB。昇腾达芬奇架构的 SRAM 叫 Unified Buffer(UB),每个 AI Core 独享,大小是 1.5MB。
UB 比 shared memory 大很多,这意味着分块策略可以不一样。NVIDIA 那边每个 block 处理的 tile 更小,需要更细粒度的分块;昇腾这边 tile 可以更大,减少循环次数。
但 UB 的带宽分配也有讲究。达芬奇架构里,UB 同时要服务于向量计算单元和矩阵计算单元(Cube Unit),如果 FlashAttention 里 Softmax 的向量计算和 QK^T 的矩阵计算争抢 UB 带宽,性能就会打折扣。ops-transformer 里的实现做了一些调度上的优化,尽量让矩阵计算和向量计算流水线化,减少等待。
差异二:矩阵计算单元的指令不同
NVIDIA 的矩阵乘用的是 Tensor Core,通过 WMMA 指令触发。昇腾的矩阵计算单元叫 Cube Unit,通过专门的矩阵乘指令触发。两者的数据排布要求不同:
- Tensor Core 要求数据按 128x128 的分块排布(FP16 场景下)
- Cube Unit 要求数据按 16x16 的分块排布(FP16 场景下)
这意味着 Q、K、V 在进入矩阵乘之前要做数据重排(layout transform)。这个重排本身也要消耗算力和带宽,如果做得不精细,重排的开销可能抵消掉 FlashAttention 带来的收益。ops-transformer 里的实现在数据加载阶段就做了 prefetch 和 layout 转换,尽量把这个开销隐藏在计算流水线里。
差异三:Softmax 的在线实现细节
FlashAttention 的核心难点是 Softmax 的在线计算。标准 Softmax 需要先扫一遍求全局最大值(防止数值溢出),再扫一遍算 exp 和归一化。但分块计算的时候,你不知道后面块的最大值是多少,所以需要一种增量更新机制。
NVIDIA 的实现用的是 FlashAttention 论文里的 online softmax 方案,每次处理新块时用当前最大值修正之前的累加结果。昇腾上的实现在算法层面是一样的,但利用了达芬奇架构的向量计算单元做一些并行化的规约操作(reduce),比 GPU 上逐元素串行修正要快。
具体来说,online softmax 的核心逻辑是这样的:
import torch
def online_softmax_update(prev_max, prev_sum, prev_out, cur_scores, cur_values):
"""FlashAttention 中 Softmax 的增量更新逻辑
每处理一个新的 KV 块,用新块的最大值修正之前的累加结果"""
# 当前块的最大值
cur_max = cur_scores.max(dim=-1, keepdim=True).values
# 全局最大值更新
new_max = torch.maximum(prev_max, cur_max)
# 修正之前的累加结果(因为分母变了)
correction = torch.exp(prev_max - new_max)
prev_sum_corrected = prev_sum * correction
prev_out_corrected = prev_out * correction
# 当前块用新最大值做 Softmax
cur_weights = torch.exp(cur_scores - new_max)
cur_sum = cur_weights.sum(dim=-1, keepdim=True)
cur_out = torch.matmul(cur_weights, cur_values)
# 合并
new_sum = prev_sum_corrected + cur_sum
new_out = (prev_out_corrected + cur_out) / new_sum
return new_max, new_sum, new_out
在昇腾上,torch.maximum、torch.exp、.sum() 这些操作会被编译成 Vector Unit 的单条向量指令,一整行数据并行处理,而 GPU 上需要多个 CUDA thread 协作完成同样的操作。
ops-transformer 里的实现长什么样
ops-transformer 仓库里 FlashAttention 的代码结构大致是这样:
ops-transformer/
└── flash_attention/
├── flash_attention_score.py # 主入口
├── flash_attention_grad.py # 反向传播
└── kernel/
├── flash_attention_tiling.py # 分块策略
└── flash_attention_kernel.cpp # Ascend C 核心实现
核心逻辑在 flash_attention_kernel.cpp 里,用 Ascend C 写的。如果你熟悉 CUDA 编程,看这个文件会有种似曾相识的感觉,但编程模型完全不同。
几个关键点:
Tiling 策略:flash_attention_tiling.py 里根据 seq_len、head_dim、UB 容量自动计算最优的 tile 大小。这个策略直接影响性能,太大了 UB 放不下,太小了循环次数多、HBM 访问频繁。
Cube 和 Vector 的流水线:矩阵乘(QK^T、PV)走 Cube Unit,Softmax 和 exp 走 Vector Unit。实现里用双缓冲机制让两套单元交替工作,Cube 算当前块的时候 Vector 在处理上一块的 Softmax。
反向传播:FlashAttention 的反向传播比前向复杂很多,需要保留前向的 Softmax 归一化因子和某些中间结果。ops-transformer 里的反向实现用了重计算策略(recomputation),不把所有中间结果都存下来,而是在反向时重新算一遍需要的中间值,用计算换显存。
实际性能对比
在昇腾 910B 上用 LLaMA 13B 做推理,FlashAttention vs 标准 Attention 的性能差异:
| 实现 | seq_len=2048 | seq_len=4096 | seq_len=8192 |
|---|---|---|---|
| 标准 Attention | 42ms | 156ms | OOM |
| FlashAttention | 18ms | 38ms | 82ms |
seq_len 越长,FlashAttention 的优势越明显。8192 的时候标准实现直接 OOM 了,因为中间矩阵放不下。FlashAttention 通过分块计算把显存占用从 O(n²) 降到了 O(n),长序列场景下几乎是唯一的选择。
FlashAttention 看起来只是"把 Attention 分块算",但真正实现起来,每一个硬件差异都要针对性地处理。昇腾 NPU 的 UB 更大、Cube Unit 的数据排布不同、Vector Unit 的并行规约方式不同,这些差异决定了你不能直接把 NVIDIA 的实现搬过来用,得重新设计 tiling 策略和流水线调度。
好消息是 ops-transformer 仓库已经把这些都做好了,而且全面开源。如果你在做大模型推理优化,建议直接用仓库里的实现,不要自己从头写。如果性能还不满足需求,可以在现有实现基础上调 tiling 参数或者改进流水线策略。
理解了 FlashAttention 在昇腾上的实现方式,再看 MoE 算子、MC2 通信算子,思路是一样的:先搞清楚算法核心,再理解硬件差异,最后看具体实现怎么在两者之间做权衡。
- Transformer 算子库:https://atomgit.com/cann/ops-transformer
- Transformer 加速库:https://atomgit.com/cann/ascend-transformer-boost
- 算子模板库:https://atomgit.com/cann/catlass
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐


所有评论(0)