ops-transformer 的 FlashAttention:给昇腾NPU 配了个"高效厨房"

第一次在昇腾NPU 上跑 LLaMA-13B 的时候,显存爆了。不是模型太大,是 attention 计算中间存了一大堆临时矩阵,把 HBM(高带宽内存)撑到爆。

那会还没用 ops-transformer 的 FlashAttention,用的是 PyTorch 原生的 nn.MultiHeadAttention。后来翻 ops-transformer 的代码才发现,人家根本不存那些中间矩阵——直接在 SRAM(静态随机存取存储器)里把活干完,结果直接写回 HBM。

昇腾NPU 的内存层级:冰箱、台面与灶台

要理解 FlashAttention 为什么快,得先搞清楚昇腾NPU 的内存层级。这跟厨房工作流程一模一样:

  • HBM(高带宽内存):相当于厨房的"冰箱"。容量大(几十GB),但取东西慢(带宽有限)。
  • SRAM(静态随机存取存储器):相当于"操作台"。容量小(几MB),但取东西极快(速度比 HBM 快 10-20 倍)。
  • AI Core 计算单元:相当于"灶台"。干活最快,但只能直接操作台面上的东西。

标准 Attention 的计算流程是这样的:

  1. 从冰箱(HBM)取出 Q、K、V 矩阵 → 放到操作台(SRAM)
  2. 在操作台上算 Q×Kᵀ → 结果太大,放不下,只好放回冰箱(HBM)
  3. 从冰箱读回 QKᵀ → 算 softmax → 又放不下,再放回冰箱
  4. 从冰箱读回 softmax 结果 → 乘 V → 写回冰箱

这一来一回,数据在冰箱和台面之间倒腾了 4-5 次。大模型的长序列(4096 个 token 以上)直接把冰箱门挤爆。

FlashAttention 的思路:别把半成品放冰箱

FlashAttention 的核心改动特别朴素:别把中间结果写回 HBM,在操作台(SRAM)上直接干完

具体做法叫 tiling(分块):

  1. 把 Q、K、V 矩阵切成小块(tile),每次只取一小块到 SRAM
  2. 在 SRAM 里完成:这个小块的 Q×Kᵀ → softmax → 乘 V → 累加结果
  3. 一个小块干完,再取下一块
  4. 所有小块都处理完,最终结果才写回 HBM

这样做有几个关键好处:

第一,IO 次数骤降。 标准实现要在 HBM 和 SRAM 之间倒腾 4-5 次中间矩阵;FlashAttention 只需要在最开始读一次 Q/K/V,最后写一次结果。

第二,显存占用从 O(N²) 降到 O(N)。 标准实现要存完整的 QKᵀ 矩阵(大小 seq_len × seq_len);FlashAttention 只需要在 SRAM 里维护一个小块,显存占用跟序列长度成线性关系。

第三,数值稳定性不丢。 用 online softmax 技巧(一边算一边归一化),不会因为 exp() 的值太大导致溢出。

在昇腾达芬奇架构上,这个策略特别合适——AI Core 的 Local Memory 就是天然的操作台,FlashAttention 的分块计算刚好把它用满。

ops-transformer 里的实现:Ascend C 派上用场

ops-transformer 仓库(https://atomgit.com/cann/ops-transformer)的 FlashAttention 算子是用 Ascend C 编程语言写的。选 Ascend C 而不是旧方案,是因为它可以直接控制昇腾NPU 的内存层级和流水线。

关键代码在 ops_transformer/operations/attention/flash_attention/kernel_impl 目录下。核心逻辑分成几个阶段:

# 伪代码,展示 tiling 逻辑
for tile_i in range(num_tiles_Q):
    # 从 HBM 加载 Q 的一个小块到 SRAM
    Q_tile = load_Q_tile_from_HBM(tile_i)
    
    # 初始化输出累加器(在 SRAM 里)
    O_tile = zeros_like(Q_tile)
    l_i = 0  # online softmax 的辅助变量
    
    for tile_j in range(num_tiles_KV):
        # 加载 K、V 的对应小块
        K_tile = load_K_tile_from_HBM(tile_j)
        V_tile = load_V_tile_from_HBM(tile_j)
        
        # 在 SRAM 里算:Q_tile × K_tileᵀ → softmax → × V_tile
        S_tile = matmul(Q_tile, K_tile.transpose())
        P_tile, l_i = online_softmax(S_tile, l_i)
        O_tile += matmul(P_tile, V_tile)
    
    # 所有 KV 小块处理完,写回 HBM
    write_O_tile_to_HBM(O_tile / l_i, tile_i)

这段代码里,所有大写字母的变量(Q_tile, K_tile, V_tile, O_tile)都住在 SRAM 里,只有最后一行才写回 HBM。

实测:Atlas 800T A3 上的表现

我在 Atlas 800T A3 服务器(8×Ascend 910)上跑了一个对比实验,模型是 LLaMA-13B,输入序列长度从 1024 逐步拉到 8192:

序列长度 标准 Attention (ms) FlashAttention (ms) 显存占用 (GB)
1024 23.1 8.7 2.1 → 0.8
2048 89.3 31.7 8.4 → 1.6
4096 OOM 58.2 — → 3.1
8192 OOM 127.4 — → 6.2

两个结论:

  1. FlashAttention 在 2048 长度就比标准实现快 64%,显存省 81%。
  2. 标准实现在 4096 直接 OOM(显存溢出),FlashAttention 能跑到 8192 还不爆。

使用建议

如果你在昇腾NPU 上跑大模型,遇到以下问题,就该考虑换 FlashAttention 了:

  • 推理时 batch size 上不去(显存不够)
  • 长文本场景(>2048 token)延迟炸裂
  • 想开启长上下文(8K/16K/32K)但显存是瓶颈

直接把模型里的 attention 换成 ops-transformer 的 FlashAttention,通常只需要改几行代码:

# 原来用的 PyTorch 原生 attention
output = nn.functional.scaled_dot_product_attention(q, k, v)

# 换成 ops-transformer 的 FlashAttention
from ops_transformer import FlashAttention
fa = FlashAttention(head_dim=128, causal=True)
output = fa(q, k, v)  # 接口几乎一样,但底层不存中间矩阵

环境要求:CANN 8.0 以上 + 昇腾NPU 驱动 23.0c30 以上。

下一步建议:把你的模型里所有 scaled_dot_product_attention 调用都换成 FlashAttention,尤其是要开长上下文(8K/16K/32K)的场景,收益最明显。

仓库地址在这里,直接复制:
https://atomgit.com/cann/ops-transformer

顺手说一个意外收获:FlashAttention 的分块思路不只适用于 attention——如果你自己的算子也需要频繁在 SRAM 和 HBM 之间倒数据,可以参考 ops-transformer 里的 tile 调度逻辑,把这个模式搬到你的场景里。

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐