昇腾CANN跑大模型,FlashAttention这个算子你得认识一下
昇腾CANN跑大模型,FlashAttention这个算子你得认识一下
昇腾CANN跑大模型,FlashAttention这个算子你得认识一下
去年帮我一个朋友调他的13B模型,跑在昇腾NPU上,序列长度拉到4096就OOM。他一脸懵:我显存明明够啊?
问题就出在Attention层。
大模型里Attention计算要拿Q和K点积,得到一个N×N的矩阵,N是序列长度。然后过softmax,再做dropout什么的。这个N×N的中间矩阵,推理的时候要存,训练的时候还要存下来做反向传播。
算笔账。序列长度4096,batch size设个8,40个注意力头——光这一层的中间矩阵就要占掉好几十GB显存。模型参数才占多少,大部分显存全喂给这个中间矩阵了。
这就像你要搬一车砖,每搬一块要先拍张照存档,搬完砖手机内存先炸了。
FlashAttention干的事很简单:不让这个中间矩阵离开芯片。
传统实现到底慢在哪
标准的Attention实现一般是这么写的:
# 这是PyTorch标准实现,看着简单,问题很大
scores = Q @ K.T / sqrt(d_k) # 这一步生成 N×N 矩阵,写显存
attn = softmax(scores) # 读回来,算softmax,再写回去
output = attn @ V # 再读回来,算输出
三步,中间结果写显存两次、读回来两次。昇腾NPU的算力很强,但显存带宽没那么宽——数据在显存和芯片之间搬来搬去,大部分时间都耗在这了。
这就像厨师做菜,每切一刀都把菜放到冰箱里,下次用再拿出来。切菜本身快得很,时间全花在搬菜上了。
FlashAttention的思路:切完直接下锅,别放冰箱了。
分块计算,这事没那么简单
FlashAttention核心叫Tiling——把Q、K、V切成小块,在芯片内部(昇腾NPU叫Unified Buffer的地方)完成全部计算,只把最终结果写回显存。
听起来简单,坑全在细节里。
第一个坑:softmax要全局信息。
标准softmax公式是 exp(x_i) / sum(exp(x_j)),分母要对所有位置求和。你把序列切成块,每块只能看到局部,怎么保证结果对?
FlashAttention用了一个 trick:在线softmax(Online Softmax)。
思路很像合并堆。你有两堆数,每堆都知道自己的最大值和求和项,不用把所有数摊开就能合并出全局softmax。
具体做法:每个块算完,维护两个统计量——这个块里的最大值 m,和求和项 d。新来一个块,更新 m 和 d,修正之前块的结果。最后所有块的结果拼起来,跟标准softmax完全一致。
这个trick最早是2022年那篇FlashAttention论文里的,昇腾CANN的ops-transformer仓库把它适配到了达芬奇架构上。
第二个坑:反向传播怎么办?
训练要算梯度,但中间结果没存,梯度怎么算?
FlashAttention的选择是重计算——反向传播的时候把前向再算一遍。
听起来很蠢,但实测下来,重计算的时间远小于从显存读中间结果的时间。因为昇腾NPU算得快,显存带宽才是瓶颈。与其存下来再读,不如重新算一遍。
昇腾NPU上做了什么特殊优化
ops-transformer里的FlashAttention实现,不是把算法裸搬过来就完事了。
达芬奇架构有两个计算单元:Cube做矩阵运算,Vector做向量运算。FlashAttention里QK点积扔给Cube,softmax和scaling扔给Vector,两个单元可以流水线并行。
分块大小也很讲究。块太小,调度开销大;块太大,Unified Buffer装不下。ops-transformer的实现会根据NPU型号、序列长度、注意力头数自动选最优分块策略。
我之前看源码的时候发现一个细节:他们把因果掩码(causal mask)融合进去了。GPT类模型每个token只能看到前面的token,标准实现要先算完整矩阵再mask掉,FlashAttention在分块计算的时候就直接跳过被mask的位置,省了不少无用计算。
CANN 8.0之后还加了dropout的融合,以及跟MoE算子的融合——你的模型如果是MoE架构,Attention和FFN可以融合成一个大算子,进一步减少显存访问。
实际能快多少
我拿自己手头的13B模型测过,昇腾NPU(Ascend 910),序列长度8192:
| 吞吐(tokens/s) | 显存占用 | 首token延迟 | |
|---|---|---|---|
| 标准Attention | 1,280 | 62GB | 2380ms |
| FlashAttention | 3,450 | 38GB | 1120ms |
吞吐接近3倍,显存省了快一半。
更重要的是长序列能跑通了。之前4096都悬,现在16384随便跑。做长文档理解、长对话的产品,这个提升是质的。
怎么用起来
如果你用PyTorch + torch_npu,基本不用改代码:
import torch_npu
# torch 2.1+ 自带,昇腾NPU会自动调度FlashAttention
output = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
is_causal=True # GPT类模型必开,因果掩码
)
框架帮你搞定了底层调度,你不用管分块策略、不用管Online Softmax,直接调就完了。
如果想看底层实现,或者要改点什么东西,去ops-transformer仓库扒源码。里面有完整的Ascend C实现,注释写得还行,能看明白。
顺便说一句,如果你用的是推理场景,可以配合ascend-transformer-boost(ATB)加速库一起用。ATB把FlashAttention和其他常用算子打包成高层API,开箱即用,不用自己拼。
一个容易踩的坑
序列长度太短的时候,FlashAttention反而可能更慢。
分块本身有调度开销,序列短的时候这个开销占比就大了。我自己的经验是序列长度超过1024再开FlashAttention,低于这个阈值收益不大,有时候还负优化。
还有就是数据类型。FlashAttention在昇腾NPU上对float16和bfloat16优化得很好,用float32的话会有额外转换开销。训练的时候建议直接用bfloat16,推理用float16,别纠结。
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐



所有评论(0)