RMSNorm 算子:让 LayerNorm 不再拖后腿
摘要: 在昇腾NPU上运行LLaMA时发现RMSNorm(替代LayerNorm)占8%推理时间,因其标准实现未优化。RMSNorm相比LayerNorm计算量少30%(省去均值计算和β参数),但标准实现存在两次HBM读写和FP16数值不稳定问题。ops-transformer通过融合Kernel(单次HBM访问)、FP32累加和多核并行,将延迟降低67%(从1.8ms至0.6ms),占比从8.2
第一次在昇腾NPU上跑 LLaMA,profiling 出来发现 LayerNorm 占了 8% 的推理时间——明明是个"小算子",怎么比 MatMul 还慢?
后来才发现:LLaMA 用的不是 LayerNorm,是 RMSNorm(Root Mean Square Layer Normalization),但标准实现没优化到位。
LayerNorm vs RMSNorm:差在哪?
LayerNorm(标准 Transformer 用):
LN(x) = γ * (x - μ) / √(σ² + ε) + β
要算:均值 μ、方差 σ²、两个可学习参数 γ 和 β。
RMSNorm(LLaMA / LLaMA-2 / Mistral 都用):
RMSNorm(x) = (x / RMS(x)) * γ
其中 RMS(x) = √(mean(x²) + ε)
只算:均方根 RMS(x)、一个可学习参数 γ。
少了什么: 不用算均值 μ,不用减均值,不用学 β 参数。
结果: RMSNorm 的计算量比 LayerNorm 少约 30%,但对精度的影
响极小(LLaMA 全程用 RMSNorm,没人觉得它精度不够)。
标准 RMSNorm 实现的问题
看起来简单(就一个均方根 + 除法),但标准实现有几个坑:
1. 需要两次 HBM 读写
标准实现:
# 第一次读:x 从 HBM 读进来
x = x.cuda()
# 算 RMS:需要把 x 的所有元素读一遍(第二次读 HBM)
rms = torch.sqrt(torch.mean(x ** 2) + eps)
# 除法:结果写回 HBM(第一次写)
out = x / rms * gamma
问题: x 被读了两遍(一次算 mean,一次做除法),HBM 带宽浪费。
2. 数值稳定性(FP16 下容易爆)
RMS 的计算是 sqrt(mean(x²)),如果 x² 很大,FP16 会上溢(最大值 65504,平方之后直接 inf)。
标准实现里经常看到:
x = x.to(torch.float32) # 先转 FP32
rms = torch.sqrt(torch.mean(x ** 2) + eps)
out = (x / rms * gamma).to(torch.float16) # 再转回 FP16
问题: 频繁转 FP32 ↔ FP16,慢。
ops-transformer 里的 RMSNorm 算子优化
1. 合并成一次 HBM 读写(Fused Kernel)
ops-transformer 的实现里,RMSNorm 整个算子只有一个 Kernel,做完所有事情:
__aicore__ void FusedRMSNorm(
AscendC::GlobalTensor<float> &x, // 输入(HBM)
AscendC::GlobalTensor<float> &out, // 输出(HBM)
AscendC::LocalTensor<float> &gamma, // 可学习参数(存在 UB)
int hiddenDim
) {
// 第一步:把 x 的一个 Tile 读进 UB(只读一次)
auto ubX = ctx.AllocTensor<float>(/*Tile 大小*/);
AscendC::DataCopy(ubX, x[offset], tileSize);
// 第二步:在 UB 里算 RMS(不读 HBM)
auto ubX2 = ubX * ubX; // x²(UB 内)
float meanX2 = AscendC::Mean(ubX2); // mean(x²)(UB 内)
float rms = sqrt(meanX2 + eps); // RMS(UB 内)
// 第三步:在 UB 里算除法 + 乘 gamma(不读 HBM)
ubX = ubX / rms * gamma; // 结果在 UB 里
// 第四步:写回 HBM(只写一次)
AscendC::DataCopy(out[offset], ubX, tileSize);
}
关键: x 只从 HBM 读一次,结果只写回 HBM 一次。中间计算全在 UB(片上内存)里完成。
2. FP16 的数值稳定性(不用转 FP32)
昇腾NPU 的 Vector 核支持 FP16 的 “加和到 FP32”(类似 TensorCore 的累积方式)。
ops-transformer 的 RMSNorm 里:
x²用 FP16 算(快)- 累加
mean(x²)的时候,累加器用 FP32(不溢出) - 最后
sqrt和除法也在 FP32 累加器里做
结果: 不用显式转 FP32 ↔ FP16,数值稳定性够,速度快。
3. 多核并行(按 Batch 维度切分)
RMSNorm 的计算是 逐 token 独立的(每个 token 的 RMS 只和自己的 hidden 维度有关,和别的 token 无关)。
ops-transformer 里把 Batch × SeqLen 维度切分到多个 AI Core 上:
- 每个 AI Core 负责 4-8 个 token 的 RMSNorm
- AI Core 之间不需要通信(每个 token 独立)
- 线性加速比(核心数翻倍,延迟减半)
实际收益(LLaMA-2 7B,Atlas 300I Duo,Batch=16)
| 配置 | RMSNorm 延迟 (ms) | 占总推理时间比例 | 数值稳定性(FP16) |
|---|---|---|---|
| 标准 RMSNorm(PyTorch) | ~1.8 | 8.2% | 偶尔 NaN(大模型) |
| + ops-transformer 优化 | ~0.6 | 2.7% | 无 NaN(FP32 累加) |
| 提升幅度 | -67% | -5.5pp | ✅ 稳定 |
代码示例(PyTorch,调用 RMSNorm)
import torch
import torch_npu
# LLaMA-2 7B 的 RMSNorm 配置
rms_norm_config = {
"hidden_size": 4096,
"eps": 1e-5,
"elementwise_affine": True, # 有 gamma 参数
}
# 在昇腾NPU上,RMSNorm 底层走的是 ops-transformer 的 Fused Kernel
# 不需要额外配置,CANN 8.0+ 自动识别
x = torch.randn(16, 4096).npu() # [Batch, Hidden]
gamma = torch.randn(4096).npu() # [Hidden]
# 调用 RMSNorm(底层是 ops-transformer 的 FusedRMSNorm Kernel)
out = F.rms_norm(x, (4096,), gamma, eps=1e-5)
# 上面的调用在昇腾NPU上走的是:
# ops-transformer 的 FusedRMSNorm(一次 HBM 读写 + FP32 累加)
一个容易踩的坑
RMSNorm 的 eps 不能设太大。
比如设 eps = 1e-3(比默认的 1e-5 大 100 倍),会导致:
- 梯度变小(RMS 的分母变大)
- 训练不稳定(尤其是大模型,层数深,梯度累积误差大)
经验值:
eps = 1e-5(LLaMA 的默认配置)✅eps = 1e-6(可以,但收益很小)✅eps ≥ 1e-4(不推荐)❌
如果你想在自己的模型里用 RMSNorm,或者想把现有 LayerNorm 改成 RMSNorm,去 ops-transformer 的 ops/norm/ 目录:
https://atomgit.com/cann/ops-transformer
里面有:
rms_norm_kernel.cpp— RMSNorm 的 Ascend C 实现(Fused Kernel)rms_norm_vs_layernorm.py— RMSNorm vs LayerNorm 的精度/速度对比examples/rms_norm_profiling.py— 跑这个脚本看 RMSNorm 的 Timeline
一句话总结:RMSNorm 不是"新算法",是"让 LayerNorm 少算一点"——少算均值、少一个可学习参数,把两次 HBM 读写合成一次,速度和稳定性都上去了。
昇腾NPU 上跑 LLaMA,RMSNorm 的优化在 Batch 大的时候(≥16)收益更明显——因为 HBM 带宽成为瓶颈,Fused Kernel 减少 HBM 访问的优势就出来了。
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐

所有评论(0)