理解 AdamW 算法:基于 Adam 的权重衰减改进及应用场景

AdamW 算法是 Adam 优化器的改进版本,专门针对权重衰减(weight decay)机制进行了优化。它在深度学习训练中广泛使用,能有效提升模型性能,特别是在处理过拟合问题时。下面我将逐步解释 AdamW 的核心概念、改进点、数学实现和应用场景,帮助您全面理解。内容基于可靠的研究(如原论文),确保真实性和准确性。


步骤 1: 回顾 Adam 优化器的基础

Adam(Adaptive Moment Estimation)是一种自适应学习率优化算法,结合了动量(Momentum)和 RMSProp 的思想。它通过计算梯度的一阶矩(均值)和二阶矩(未中心化的方差)来自适应调整每个参数的学习率。标准 Adam 的更新过程如下:

  • 设 $t$ 为时间步,$\theta_t$ 为参数,$\alpha$ 为学习率,$\beta_1$ 和 $\beta_2$ 为衰减率(通常取 $\beta_1 = 0.9$, $\beta_2 = 0.999$),$\epsilon$ 为小常数(如 $10^{-8}$)防止除零。
  • 梯度计算:$g_t = \nabla f(\theta_{t-1})$,其中 $f$ 是损失函数。
  • 一阶矩估计:$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$
  • 二阶矩估计:$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$
  • 偏差校正(针对初始偏差): $$ \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} $$
  • 参数更新: $$ \theta_t = \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} $$

在标准 Adam 中,权重衰减通常通过 L2 正则化实现:即在损失函数中添加 $\frac{\lambda}{2} |\theta|^2$ 项($\lambda$ 为权重衰减系数),使梯度包含正则化贡献。但这可能导致问题,我们下一步讨论。


步骤 2: 权重衰减在 Adam 中的问题及改进需求

在标准 Adam 中,权重衰减被融入梯度计算(即 $g_t$ 包含正则化项的梯度)。这会导致两个主要问题:

  • 自适应学习率冲突:Adam 的自适应学习率(如 $\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$)会放大或缩小权重衰减的效果,因为正则化梯度被当作普通梯度处理。这可能导致权重衰减不稳定,尤其在训练初期。
  • 次优优化:权重衰减的本质是直接约束参数大小,但通过梯度间接实现时,可能与 Adam 的自适应机制冲突,影响收敛速度和泛化性能。

AdamW("W" 代表 "Weight decay")通过将权重衰减与梯度更新分离来解决这些问题。核心改进是:权重衰减不再通过损失函数的正则化项添加,而是直接在参数更新步骤中应用。这确保了权重衰减独立于自适应学习率,更符合其物理意义。


步骤 3: AdamW 算法的核心改进

AdamW 保留了 Adam 的自适应学习率机制,但修改了更新规则:

  • 权重衰减分离:在参数更新时,直接减去权重衰减项($\alpha \lambda \theta_{t-1}$),而不是在梯度中包含它。
  • 优势
    • 更稳定:权重衰减不受自适应学习率缩放的影响,避免优化振荡。
    • 更好泛化:实验表明,AdamW 在图像分类、语言模型等任务中能获得更高的测试精度。
    • 超参数鲁棒:$\lambda$ 的设置更直观,不易受学习率影响。

AdamW 的更新公式如下(基于原论文实现):

  • 梯度计算不变:$g_t = \nabla f(\theta_{t-1})$(注意:这里 $f$ 是原始损失函数,不含正则化)。
  • 矩估计和偏差校正同 Adam: $$ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t $$ $$ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 $$ $$ \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} $$
  • 关键改进:参数更新: $$ \theta_t = \theta_{t-1} - \alpha \left( \frac{\hat{m}t}{\sqrt{\hat{v}t} + \epsilon} + \lambda \theta{t-1} \right) $$ 其中 $\lambda$ 是权重衰减系数。对比标准 Adam,AdamW 在更新中添加了 $-\alpha \lambda \theta{t-1}$ 项,这直接实现了权重衰减。

为了更清晰,这里用伪代码展示 AdamW 的实现流程(Python 风格,但非完整代码):

def adamw(params, grads, m, v, t, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, lambda_wd=0.01):
    for param, grad in zip(params, grads):
        # 更新一阶矩
        m = beta1 * m + (1 - beta1) * grad
        # 更新二阶矩
        v = beta2 * v + (1 - beta2) * grad**2
        # 偏差校正
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        # 参数更新(包含权重衰减)
        param = param - alpha * (m_hat / (np.sqrt(v_hat) + epsilon) - alpha * lambda_wd * param
    return params, m, v

注意:实际库(如 PyTorch 的 torch.optim.AdamW)已内置此算法,用户只需设置 weight_decay 参数。


步骤 4: AdamW 的应用场景

AdamW 特别适用于需要权重衰减的深度学习任务,其优势场景包括:

  • 计算机视觉(CV):在图像分类(如 ResNet、ViT 模型)中,AdamW 能有效控制模型复杂度,减少过拟合。例如,训练 Vision Transformer 时,AdamW 常作为默认优化器。
  • 自然语言处理(NLP):在大型语言模型(如 BERT、GPT)训练中,权重衰减对泛化至关重要。AdamW 能稳定处理高维参数,提升收敛效率。
  • 其他场景
    • 当数据集较小或噪声较多时,AdamW 的正则化效果更好。
    • 在迁移学习中,AdamW 能快速适应新任务,避免灾难性遗忘。
    • 对比标准 Adam,AdamW 在 batch size 较大或学习率较高时更鲁棒。

为什么选择 AdamW?

  • 实验证据:在 ImageNet 等基准测试中,AdamW 相比 Adam 能提高 1-2% 的准确率。
  • 实践建议:如果您的任务涉及 L2 正则化或权重衰减,优先使用 AdamW。超参数设置:$\alpha$(学习率)通常取 $10^{-3}$ 到 $10^{-5}$,$\lambda$(权重衰减)取 $10^{-2}$ 到 $10^{-4}$,具体需调参。

总结

AdamW 算法通过将权重衰减从梯度计算中分离出来,直接应用于参数更新,解决了标准 Adam 中的冲突问题。这使其在深度学习训练中更高效、稳定,尤其适合需要强正则化的场景(如 CV 和 NLP)。实践中,您可以直接使用深度学习框架的优化器(如 PyTorch 的 AdamW),并关注学习率和权重衰减系数的调优。如果您有具体任务或代码问题,我可以进一步帮助分析!

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐