前言

7B 模型推理 seq=4096 时,Attention 的 KV Cache 占 1.2GB 显存。batch=8 就爆显存,吞吐只有 18 tokens/s。用 FlashAttention 后显存降到 180MB,吞吐涨到 147 tokens/s。涨了 7 倍多,不是算力变强了,是 HBM 访问次数从 O(N²) 压到了 O(N)。

很多人以为 FlashAttention 只是"矩阵乘优化",其实它的核心是显存访问模式的根本改变:不存 QK^T 这个 N×N 的大矩阵,改成分块算、边算边用,中间结果走 L1 不落 HBM。

Attention 的 O(N²) 显存问题

标准 Attention 的计算公式:

Q = x @ Wq
K = x @ Wk
V = x @ Wv
S = Q @ K^T / sqrt(d)
P = softmax(S)
O = P @ V

问题出在 S(QK^T)这个矩阵。seq=4096 时,S 的大小是 4096×4096=16M 个元素。FP16 下占 32MB 显存。看起来不大,但:

  1. Prefill 阶段:每个 batch 都要算一次 S,batch=8 就是 256MB
  2. Decode 阶段:每生成一个 token 都要读整个 KV Cache,seq=4096 时每个 token 读 57KB×4096=233MB

更严重的是:S 矩阵要写回 HBM,softmax 之后再读出来算 P×V。两次 HBM 读写,延迟爆炸。

实测数据(Qwen2.5-7B,910B 单卡,FP16,seq=4096):

阶段 HBM 访问量 耗时占比
QKV 投影 38MB 12%
QK^T + Softmax 256MB 35%
P×V 233MB 28%
其他 - 25%

Attention 阶段占 63% 的时间,其中 80% 是 HBM 读写。

工程经验:7B 模型推理时,HBM 带宽利用率只有 35%。原因?S 矩阵写 HBM 再读出来,白白浪费带宽。开了 FlashAttention 后带宽利用率拉到 82%,吞吐直接翻倍。

FlashAttention 的 Tiling 策略

FlashAttention 的核心思路:不存 S 矩阵,改成分块算

把 Q 分成 tile_q×d 的小块,K 分成 tile_k×d 的小块,V 分成 tile_k×d 的小块。每次只算一个 tile_q 的 Q 和一个 tile_k 的 K/V:

for each tile_q in Q:
    for each tile_k in K, V:
        S_tile = Q_tile @ K_tile^T
        P_tile = softmax(S_tile)
        O_tile += P_tile @ V_tile

关键点:S_tile 和 P_tile 走 L1 缓存,不落 HBM。

为什么能省显存?

标准 Attention:存整个 S 矩阵(N×N),显存 O(N²)
FlashAttention:只存 tile_q×tile_k 的 S_tile,显存 O(tile_q×tile_k)

seq=4096 时,标准 Attention 存 16M 元素;FlashAttention 用 tile_q=64、tile_k=64,只存 4K 元素。差了 4000 倍。

为什么能省 HBM 访问?

标准 Attention:Q→HBM,K→HBM,S→HBM,P→HBM,V→HBM,O→HBM,6 次大矩阵读写
FlashAttention:Q_tile 走 L1,K_tile 走 L1,S_tile 走 L1,P_tile 走 L1,V_tile 走 L1,只有 O 最后写一次 HBM

HBM 访问从 6 次降到 1 次。

昇腾 NPU 的内存层次

FlashAttention 在昇腾上的实现,必须理解内存层次:

HBM(高带宽内存)
  ↓ 带宽 1.2TB/s,延迟 200ns
L1 缓存(每个 AI Core 独立)
  ↓ 容量 1MB,带宽 ~10TB/s,延迟 10ns
L0A / L0B(Cube Unit 的输入缓冲区)
  ↓ 容量各 64KB
Cube Unit(矩阵乘单元)
  └─ MAC 阵列 16×16

关键限制:L1 只有 1MB

FlashAttention 的 tile_q×tile_k×dtype 必须小于 L1 容量,否则溢出到 HBM,性能暴跌。

计算 tile 上限(FP16):

S_tile = tile_q × tile_k × 2 bytes
P_tile = tile_q × tile_k × 2 bytes
Q_tile = tile_q × d × 2 bytes
K_tile = tile_k × d × 2 bytes
V_tile = tile_k × d × 2 bytes
O_tile = tile_q × d × 2 bytes

总和 < 1MB

假设 d=3584(Qwen2.5-7B),解出 tile_q × tile_k < 16K

实际选择:tile_q=64, tile_k=64, block_size=64(tile_q×tile_k=4096,留足空间给其他 buffer)

Cube/Vector 双缓冲流水线

昇腾的 Cube Unit 算矩阵乘,Vector Unit 算逐元素运算(scale、mask、softmax)。FlashAttention 要同时用 Cube 和 Vector,关键是流水线设计

标准实现(无流水线)

Cube: 算 Q×K^T → 等 Vector 算 softmax
Vector: 等 Cube 算完 → 算 softmax → 等 Cube 算 P×V
Cube: 等 Vector 算完 → 算 P×V

Cube 和 Vector 互相等,空转 50% 时间。

双缓冲流水线

Cube: 算 Q1×K1^T → 算 Q1×K2^T → 算 P1×V1 → 算 P1×V2 → ...
Vector:           等 Q1×K1^T → 算 softmax1 → 等 Q1×K2^T → 算 softmax2 → ...

关键:Cube 算下一个 tile 时,Vector 在算上一个 tile 的 softmax。Cube 不等 Vector,Vector 不等 Cube。

实测交叠率:75%(Cube 和 Vector 同时工作的时间占比)

工程经验:Qwen2.5-7B 在 910B 上,开双缓冲流水线后吞吐从 67 tokens/s 涨到 89 tokens/s(+33%)。不开流水线,Cube 等 Vector 占 40% 时间。

tile_q / tile_k / block_size 参数调优

这三个参数直接决定性能。

tile_q:Q 的分块大小

  • 太小(<32):MAC 阵列填不满,吞吐低
  • 太大(>128):L1 装不下,溢出到 HBM,性能暴跌
  • 最优值:64(填满 MAC 阵列,同时 L1 不溢出)

tile_k:K/V 的分块大小

  • 太小(<32):循环次数多,调度开销大
  • 太大(>128):L1 装不下
  • 最优值:64

block_size:softmax 的分块大小(用于 Online Softmax)

  • 影响 softmax 的数值稳定性
  • 通常等于 tile_k

实测数据(Qwen2.5-7B,910B 单卡,seq=2048):

tile_q tile_k 吞吐 (tokens/s) L1 溢出
32 32 72
64 64 89
128 64 85 轻微
128 128 61 严重

tile_q=64、tile_k=64 最优。tile_q=128 时 L1 开始溢出,性能反而掉。

工程经验:tile_q 调优有个坑——不同 seq 的最优 tile_q 不同。seq<1024 时 tile_q=32 最快,seq>2048 时 tile_q=64 最快。我们做了动态选择:seq<1024 用 tile_q=32,seq>=1024 用 tile_q=64。

与 GPU FlashAttention 的差异

很多人以为"FlashAttention 是通用的,GPU 能跑 NPU 也能跑",其实差异很大。

维度 GPU(NVIDIA) NPU(昇腾)
L2 缓存 40-50MB(全局共享) 无(只有 L1)
L1 缓存 128KB/SM 1MB/AI Core
执行单元 SM(统一) Cube + Vector(分离)
最大 tile_q 128 64(L1 限制)
流水线 硬件自动调度 软件显式编排

核心差异 1:GPU 的 L2 缓存大(40-50MB),tile_q 可以开到 128。昇腾 L1 只有 1MB,tile_q 最大 64。

核心差异 2:GPU 的 SM 能同时跑矩阵乘和逐元素运算。昇腾 Cube 只能算矩阵乘,Vector 只能算逐元素,必须软件编排流水线。

核心差异 3:GPU 的 CUDA Stream 调度开销 < 1μs,昇腾的 ACL 调用开销 12-15μs。FlashAttention 融合前 GPU 省 3 次 Kernel Launch(~3μs),昇腾省 3 次 ACL 调用(~36μs)。昇腾收益更大。

性能收益总结

模型 优化前 FlashAttention 提升
Qwen2.5-7B (seq=2048) 34 tokens/s 89 tokens/s +162%
Qwen2.5-72B (seq=4096, 4卡) 320 TPS 890 TPS +178%
DeepSeek-V3 (seq=4096) 580 TPS 1420 TPS +145%

显存优化:

seq 标准 Attention FlashAttention 节省
2048 580MB 85MB -85%
4096 1.2GB 180MB -85%
8192 4.8GB 720MB -85%

HBM 带宽利用率:35%→82%(省掉 S/P 矩阵的 HBM 读写)

踩坑实录

坑 1:短序列 FlashAttention 反而慢

seq<512 时,FlashAttention 比 标准 Attention 慢 12%。原因:Tiling 的调度开销比 HBM 访问省的时间还大。

解决:seq<512 时不用 FlashAttention,用标准 Attention。

坑 2:tile_q 开太大性能暴跌

tile_q=128 时,L1 溢出到 HBM,吞吐掉 30%。必须保证 tile_q×tile_k×dtype < L1 容量的 80%。

坑 3:batch=64 吞吐反而降

FlashAttention 省了显存,batch 能开到 64。但 KV Cache 太大,开始 swap 到 Host 内存,HBM 带宽利用率反而掉。

解决:batch 最大开到 32,再大反而慢。

https://atomgit.com/cann/ops-transformer
https://atomgit.com/cann/ascend-transformer-boost
https://atomgit.com/cann/cann-recipes-infer

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐