CANN-ops-math高精度Softmax-昇腾NPU上float32归一化为什么不能省

大模型推理大多用 float16 跑,唯独 Softmax 这一步必须升到 float32。不是开发者保守,是 float16 的指数运算真的会溢出。ops-math 的高精度 Softmax 在昇腾NPU上做了什么,为什么融合算子里也得留这一步,这篇讲清楚。

float16 Softmax 的溢出问题

float16 的最大值是 65504。Softmax 的计算过程:

1. 减最大值:x_i - max(x)  → 所有值 ≤ 0
2. 指数运算:exp(x_i - max(x))  → 所有值 ≤ 1
3. 归一化:exp_i / sum(exp)

第 1 步保证指数运算不溢出。但如果跳过第 1 步直接算 exp,float16 下只要 x > 11 就溢出(exp(11) ≈ 59874,接近 65504)。

问题出在减最大值本身也用 float16。假设序列长度 4096,logits 的方差约 8(初始化后的典型值),最大值可能到 30-40。float16 减法在大值附近的精度约 0.06,减完的结果在 [-40, 0] 区间,精度只有 0.06。这个误差在 exp 之后被放大——exp(-0.06) / exp(0) ≈ 0.94,4% 的相对误差。

float32 减法的精度约 1e-7,exp 后的误差在 1e-6 量级。差了四个数量级。

ops-math 的实现

ops-math 的 Softmax 标准流程:

1. 输入 float16 → Cast → float32
2. float32 减最大值
3. float32 指数运算
4. float32 归一化
5. 输出 float32 → Cast → float16

两次 Cast 看起来浪费,但昇腾NPU的 Vector 单元做 Cast 只需 1-2 个时钟周期,跟 exp 运算相比可以忽略。

FlashAttention 里的 Softmax

FlashAttention 把 Softmax 融合进了 Attention kernel,但 Softmax 的 float32 精度保持没有省:

FlashAttention kernel 内部:
  Q·K^T → float16 结果
  → Cast float32
  → 减最大值(float32)
  → 指数运算(float32)
  → 归一化(float32)
  → Cast float16
  → 乘 V(float16)

这也是为什么 FlashAttention 的显存占用比标准 Attention 少——中间的 Softmax 结果不需要存到 HBM,在片上缓存完成 float32 计算后直接 Cast 回 float16 继续算。

bf16 能不能绕过这个问题

bf16 的指数范围跟 float32 一样大(±3.4×10^38),但精度只有 2 位十进制。bf16 的 Softmax 问题不是溢出,是精度:

bf16 减法精度:约 0.02(在 [-40, 0] 区间)
fp32 减法精度:约 1e-7

0.02 的误差在 exp 后约 2% 的相对误差。对于 Attention 权重来说,2% 的误差会直接影响 token 的关注度排序——Softmax 本来就是为了区分"该看哪里",2% 的噪声可能让模型关注到错误的 token。

所以 bf16 也得升 fp32 做 Softmax。结论:不管输入是 fp16 还是 bf16,Softmax 必须用 fp32 算。

在线 Softmax vs 两遍 Softmax

标准 Softmax 是两遍的:第一遍找最大值,第二遍算指数和归一化。FlashAttention 用的是在线 Softmax(Online Softmax),一遍完成:

在线 Softmax:
  逐块处理,维护运行中的最大值 m 和累积和 l
  每处理一个新块:
    m_new = max(m_old, max(new_block))
    l_new = l_old * exp(m_old - m_new) + sum(exp(new_block - m_new))
    O_new = (O_old * l_old * exp(m_old - m_new) + new_block_result) / l_new

在线 Softmax 的核心优势是只需要一遍扫描,配合 FlashAttention 的分块计算正好。但数学上它引入了额外的乘法和除法来修正之前的累积结果——这些修正也必须在 float32 下完成。

性能影响

float32 Softmax 的额外开销:

步骤 float16 版本 float32 版本
减最大值 0.008ms 0.008ms
指数运算 0.012ms 0.014ms
归一化 0.006ms 0.006ms
Cast × 2 0 0.002ms
总计 0.026ms 0.030ms

float32 只慢 15%,但精度提升四个数量级。这笔交易怎么算都值。


Softmax 用 float32 不是保守,是数学上的必然。float16 的指数精度不够,bf16 也一样。ops-math 的高精度 Softmax 在融合算子内部自动处理,用户不需要手动 Cast。但如果你自己写 Softmax 相关的算子,记住这个规则。仓库在这里:

https://atomgit.com/cann/ops-math

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐