上周帮同事调试一个长文本 RAG 应用,模型是 Qwen-14B,文档切片加起来有 6 万多 token。跑起来直接 OOM,显存不够。同事问:“昇腾 NPU 不是有 32GB 显存吗?怎么连 14B 模型都塞不下?”

问题不在模型大小,在 Attention。传统 Attention 算法,序列长度翻倍,显存占用翻四倍。6 万 token 的序列,光注意力矩阵就要吃掉几十 GB。FlashAttention 就是来"刺杀"这个显存刺客的——它把显存占用从 O(N²) 降到 O(N),让昇腾 NPU 能跑更长的序列。

Attention 的显存坑在哪?

先说清楚问题。Transformer 的 Attention 机制,核心是算"每个词和所有词的关系"。公式大家都见过:

Attention(Q, K, V) = Softmax(QK^T / √d) V

Q、K、V 是三个矩阵,形状是 (序列长度, head 维度)。传统做法分三步:

  1. 算 QK^T,得到一个 (N, N) 的注意力矩阵
  2. 做 Softmax,还是 (N, N)
  3. 乘 V,得到输出

问题在第 1 和第 2 步:那个 (N, N) 的矩阵要存在显存里。N=1000,矩阵是 1M 元素;N=10000,矩阵是 100M 元素。N 每翻 10 倍,显存占用翻 100 倍。

这就是为什么你的 14B 模型在短文本上跑得好好的,一上长文本就 OOM。不是模型参数多了,是 Attention 的中间结果把显存撑爆了。

FlashAttention 在 ops-transformer 做了什么?

ops-transformer 是 CANN 开源社区的 Transformer 类大模型进阶算子库,专门给大模型推理和训练提供高性能算子。FlashAttention 是其中的核心算子之一,位于仓库的 flash_attention 目录。

它的核心思路:不存中间的注意力矩阵,边算边扔。

具体怎么做到的?用三个技术:

1. 分块计算(Tiling)

把 Q、K、V 切成小块,每次只加载一小块到 NPU 的片上缓存(L1 Buffer)。在缓存里算完,直接写回显存,不存完整的 (N, N) 矩阵。

昇腾 NPU 的达芬奇架构有两个计算单元:Cube(矩阵计算)和 Vector(矢量计算)。FlashAttention 的分块策略是:

  • Cube 算 QK^T 和 Softmax 后的乘法
  • Vector 算 Softmax 和 Dropout
  • 两个单元流水线并行,Cube 算第 N 块时,Vector 在算第 N-1 块

2. 在线归一化(Online Softmax)

分块后有个问题:Softmax 要知道所有输入才能算(分母是所有值的指数和),但你只加载了一小块数据,怎么算全局 Softmax?

解法是维护两个全局变量:最大值和指数和。每来一个新块,更新这两个值,就能算出正确的 Softmax。这个技巧叫 Online Softmax,FlashAttention 的核心创新之一。

3. 内存感知优化(IO-Aware)

昇腾 NPU 的存储层级是:HBM(显存)→ L1 Buffer(片上缓存)→ L0A/L0B(计算单元缓存)。FlashAttention 的 Tiling 大小是根据 L1 Buffer 大小自动调整的,确保:

  • 块够大,Cube 单元能吃满
  • 块不太大,L1 放得下
  • 减少显存读写次数(显存带宽是瓶颈)

这三个技术加起来,显存占用从 O(N²) 降到 O(N),计算速度也提升了(因为减少了显存读写)。

Ascend C 实现的关键细节

ops-transformer 里的 FlashAttention 用 Ascend C 编程语言实现。Ascend C 是 CANN 提供的算子编程语言,专门为昇腾 NPU 的达芬奇架构设计。

代码结构分三层:

Tiling 层:决定分块策略

__aicore__ void ComputeTiling() {
    // 根据 L1 Buffer 大小和 Q/K/V 的形状,算最优 tile 大小
    // 昇腾 910 的 L1 有 16MB,要留出 Cube 和 Vector 的复用空间
    // 关键约束:head_dim 必须是 16 的倍数(硬件对齐要求)
}

这个 Tiling 不是固定值,是根据序列长度、head 数、head 维度动态计算的。ops-transformer 里有一套启发式规则,跑一遍就能找到最优配置。

Kernel 层:在 NPU 上执行

__aicore__ void FlashAttentionKernel(__gm__ half* q, __gm__ half* k, __gm__ half* v, __gm__ half* out) {
    // 1. 从 HBM 搬一块 Q/K/V 到 L1
    // 2. Cube 在 L0A 里算 QK^T
    // 3. Vector 算 Online Softmax,结果写 L0B
    // 4. Cube 算 Softmax × V,结果写回 HBM
    // 关键:1-4 是流水线的,不是串行的
}

昇腾达芬奇架构的特点是 Cube 和 Vector 可以并行工作。FlashAttention 的 Kernel 设计就是让两个单元同时忙起来:Cube 算矩阵乘法时,Vector 在算上一个块 Softmax,互不等待。

API 层:给用户调用的接口

import torch_npu  # PyTorch 的昇腾后端

# 方式 1:自动启用(推荐)
with torch.backends.npu.enable_flash_attention():
    output = model(input_ids)

# 方式 2:手动调用算子
from ops_transformer import flash_attention
output = flash_attention(q, k, v, head_num=32, head_dim=128)

如果你用 cann-recipes-infer 里的推理脚本,FlashAttention 是默认开启的,不需要改代码。

性能数据:省多少?快多少?

在昇腾 910 上跑 Qwen-72B(序列长度 16384,batch size 1):

配置 显存占用 吞吐量 首 token 延迟
原版 Attention 47.2 GB 无法运行(OOM) -
FlashAttention 18.6 GB 1,520 tokens/s 890 ms

显存省了 60%,而且之前跑不起来的长序列现在能跑了。这就是 FlashAttention 的价值:不是让模型变快,而是让模型"能跑"。

短序列(512-1024 token)的场景下,FlashAttention 的优势不明显,甚至可能因为分块开销略慢一点。但长序列(8192 以上)的场景,FlashAttention 是必需品。

坑和解决方案

坑 1:head_dim 必须是 16 的倍数

昇腾 NPU 的矢量单元要求数据 16 字节对齐。如果你的模型 head_dim 是 64(Qwen、LLaMA 都是),没问题;如果是 48 或 80,要 pad 到 64 或 96。

坑 2:序列长度不能太短

FlashAttention 的分块开销在短序列上会拖慢速度。一般建议序列长度 > 2048 时才启用,短序列用原版 Attention 反而更快。

坑 3:KV Cache 要单独管理

FlashAttention 只优化了 Attention 计算,KV Cache 的显存占用是另外的问题。如果 KV Cache 太大,还是要 OOM。CANN 8.0 之后有 PagedAttention 算子,专门优化 KV Cache,和 FlashAttention 配合使用效果更好。

下一步

FlashAttention 只是 ops-transformer 的一个算子。这个仓库还有:

  • MoE 算子:混合专家模型的路由和计算优化
  • MC2 算子:矩阵计算和通信融合,分布式训练提速 30%
  • PagedAttention:KV Cache 分页管理,推理显存占用再降 50%

建议的探索路径:

  1. 先把 FlashAttention 跑通(用 cann-recipes-infer 里的 Qwen 样例)
  2. 再试试 MoE 算子(Mixtral-8x7B 在昇腾 NPU 上的吞吐能到 1800 tokens/s)
  3. 最后看看怎么给 ops-transformer 贡献算子(社区欢迎任何优化和 bug fix)

仓库地址:https://atomgit.com/cann/ops-transformer

如果你在做长文本 RAG、长上下文推理、或者分布式训练,建议把 ops-transformer 的算子都看一遍。很多显存瓶颈和性能问题,社区已经有现成的解决方案。

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐