CANN8.5-ops-nn新融合算子-昇腾NPU推理加速又多了几把刷子

CANN 8.5 对 ops-nn 仓库做了三个重要更新:新增 RMSNorm + 残差融合算子、Linear + GELU 融合接口、以及 LayerNorm 的动态 shape 支持。前两个直接提升推理性能,第三个解决了一个长期以来的兼容性痛点。

新增:RMSNorm + 残差融合

Llama 架构每一层有两个残差连接,每个残差连接后面跟一个 RMSNorm。标准实现是两个独立 kernel:elementwise add + RMSNorm。CANN 8.5 把它们合成了一个。

import torch_npu

# CANN 8.0:两个 kernel
residual = x + sublayer_out           # kernel 1
normed = torch_npu.npu.rms_norm(residual, w, eps)  # kernel 2

# CANN 8.5:一个 kernel
normed, residual_new = torch_npu.npu.fused_add_rms_norm(
    x, sublayer_out, w, eps
)
# 同时输出归一化结果和更新后的残差

收益不只是少一次 kernel launch。融合算子还避免了残差结果的 HBM 写入——elementwise add 的结果直接在片上缓存传给 RMSNorm,32 层省约 2.1 GB 的 HBM 读写。

Atlas 800I A2 上 Llama2-7B 的 decode 延迟对比:

配置 首 token 延迟 (ms) decode 速度 (tokens/s)
分离残差+RMSNorm 78 2,840
融合 add+RMSNorm 68 3,260

decode 速度提升 15%。这个收益在 prefill 阶段不明显(prefill 是 compute-bound,不是 memory-bound),但在 decode 阶段很实在。

新增:Linear + GELU 融合

之前 ops-nn 只有 linear_activation 接口支持 ReLU 和 SiLU。GELU 因为计算更复杂(tanh 近似),CANN 8.0 不支持跟 Linear 融合。

CANN 8.5 补上了:

import torch_npu

# Linear + GELU 融合
out = torch_npu.npu.linear_activation(x, w, b, activation="gelu")

这对 BERT 类模型是刚需。BERT 的 FFN 层用 GELU 激活,之前没法跟 Linear 融合,每次都有一次额外的 HBM 读写。

性能对比(BERT-base FFN 层):

配置 延迟 (ms)
Linear + GELU 分离 0.38
融合 0.29

加速 24%。GELU 的融合收益比 SiLU 大——因为 GELU 单独计算需要 4 条 Vector 指令,融合后这些指令跟 Cube 的 MatMul 流水执行,延迟被 MatMul 覆盖了一部分。

LayerNorm 动态 shape 支持

这是一个兼容性修复。CANN 8.0 的 LayerNorm 要求归一化维度在编译时确定,推理时如果遇到不同长度的输入(比如 batch 里 padded 到不同长度),GE 需要重新编译计算图。

CANN 8.5 的 LayerNorm 支持运行时动态 shape,不需要重新编译。这对推理服务的 batch 调度特别重要——不同请求的序列长度不同,动态 shape 让你不用每个长度都编译一份计算图。

# CANN 8.0:不同 seq_len 触发重新编译
x1 = torch.randn(1, 128, 4096, device="npu")   # 编译一次
x2 = torch.randn(1, 256, 4096, device="npu")   # 又编译一次
x3 = torch.randn(1, 512, 4096, device="npu")   # 再编译一次

# CANN 8.5:同一份计算图复用
# 不需要重复编译,首次编译后所有 seq_len 都能用

首次推理的编译延迟从 ~2s 降到 ~0.5s,后续请求零编译开销。

升级建议

场景 是否需要升级
Llama 推理(decode 性能敏感) 建议升级,融合 RMSNorm 收益 15%
BERT 推理 建议升级,GELU 融合收益 24%
推理服务有变长输入 建议升级,动态 shape 免编译
只做训练 不急,训练场景上述优化影响不大
CV 模型推理 不急,以上三个更新主要服务 NLP 模型

CANN 8.5 对 ops-nn 的更新集中在"推理场景的融合补全"上。RMSNorm+残差和 Linear+GELU 是之前用户反馈最多的两个融合缺口,这次都补上了。如果你的推理服务还在 CANN 8.0,这两个更新值得你做一次升级。仓库在这里:

https://atomgit.com/cann/ops-nn

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐