请添加图片描述

前言

模型训练慢,不一定是算力不够。下面 10 个技巧,按收益从大到小排列,大部分改几行代码就能生效。


1. 开启混合精度训练

收益:速度提升 1.5-2 倍,精度几乎无损失

from torch_npu.contrib import transfer_to_npu

model = Model().npu()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 开启自动混合精度
scaler = torch.npu.amp.GradScaler()
for data, target in dataloader:
    with torch.npu.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

原理:大部分算子用 FP16 计算(速度是 FP32 的 2 倍),少量对精度敏感的算子(LayerNorm、Softmax)保留 FP32。昇腾 910 的 Cube Unit 做 FP16 矩阵乘的吞吐量是 FP32 的 4 倍,所以混合精度的收益在昇腾上比 GPU 更明显。


2. 用 DDP 而不是 DataParallel

收益:多卡线性加速比从 0.5 提到 0.9

import torch.distributed as dist

dist.init_process_group(backend="hccl")
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank]
)

DataParallel 是单进程多线程,受 GIL 限制,4 卡加速比只有 2 倍。DDP 是多进程,每张卡一个进程,4 卡加速比 3.6 倍。注意 backendhccl,不是 nccl


3. 梯度累积代替大 batch

收益:避免 OOM,显存占用降低 50%

accumulation_steps = 4
optimizer.zero_grad()

for i, (data, target) in enumerate(dataloader):
    with torch.npu.amp.autocast():
        output = model(data)
        loss = criterion(output, target) / accumulation_steps
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

实际 batch = batch_per_gpu × accumulation_steps × num_gpus。效果和大 batch 等价,但每次前向传播只占 1/4 的显存。


4. 用 torch.compile 加速图优化

收益:速度提升 10-20%

model = torch.compile(model, backend="npu")

昇腾的 torch.compile 后端会把 PyTorch 的动态图转成 GE 的静态图,触发算子融合和内存优化。首次执行会编译(约 30 秒),之后重复执行不需要重编译。


5. 数据预取

收益:训练速度提升 5-15%(取决于数据加载是否是瓶颈)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    num_workers=8,      # 8 个进程并行加载
    pin_memory=True,    # 锁页内存,加速 CPU→NPU 搬运
    prefetch_factor=4   # 预取 4 个 batch
)

pin_memory=True 让数据分配在锁页内存里,CPU 到 NPU 的搬运走 DMA,不需要 CPU 参与。num_workers=8 让数据加载和模型计算重叠。


6. FlashAttention 加速 Transformer

收益:序列 >1024 时,Attention 计算快 3 倍,显存省 80%

from ops_transformer import flash_attention

# 替换标准 Attention
output = flash_attention(query=Q, key=K, value=V)

不需要手动修改模型代码。ascend-transformer-boost 提供了自动替换功能:

import ascend_transformer_boost
model = ascend_transformer_boost.optimize(model)

自动把模型中的标准 Attention 替换成 FlashAttention,同时优化 KV Cache 管理。


7. 梯度压缩(HCCL 原生支持)

收益:通信量减半,多卡训练提速 10-15%

# HCCL 后端自动支持 FP16 梯度压缩
dist.init_process_group(backend="hccl")

NCCL 没有原生梯度压缩,需要手动用 ddp_comm_hooks。HCCL 在库层面支持,不需要额外代码。


8. 优化器状态卸载到 CPU

收益:大模型训练显存节省 30-40%

from torch.distributed.optim import ZeroRedundancyOptimizer

optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=1e-4
)

AdamW 的优化器状态(m 和 v)占的显存是模型参数的 2 倍。卸载到 CPU 之后,LLaMA-7B 的训练显存从 42GB 降到 26GB,单卡 910 就能跑。


9. 检查点只存权重不存梯度

收益:checkpoint 体积缩小 3 倍,保存速度提升 5 倍

# 不要这样存
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),  # 体积是权重的 2-3 倍
}, 'checkpoint.pt')

# 这样存
torch.save({
    'model': model.state_dict(),  # 只存权重
}, 'checkpoint.pt')

优化器状态可以从权重重新初始化,不需要存。如果训练中断,从最近的权重重新开始即可,损失最多一个 epoch 的进度。


10. 关闭不必要的同步

收益:训练速度提升 3-5%

# 不要在训练循环里加这个
torch.npu.synchronize()  # 等待所有 NPU 操作完成

# 也不要频繁打印 loss
if i % 1 == 0:  # 每个 step 都打印
    print(f"loss: {loss.item()}")  # .item() 会触发同步

.item().cpu() 都会触发 CPU-NPU 同步,破坏异步执行的流水线。打印频率改成每 100 步一次,或者用 TensorBoard 的 add_scalar(异步写入)。


技巧收益汇总

技巧 改动量 速度提升 适用场景
混合精度 5 行代码 1.5-2x 所有模型
DDP 替换 DP 10 行代码 1.5-3x 多卡训练
梯度累积 8 行代码 避 OOM 大 batch 训练
torch.compile 1 行代码 10-20% 静态图友好模型
数据预取 3 行配置 5-15% IO 密集型
FlashAttention 1 行代码 3x(Attention) Transformer
梯度压缩 0 行代码 10-15% 多卡训练
优化器卸载 3 行代码 省 30% 显存 大模型
轻量检查点 2 行代码 5x 保存速度 长时间训练
减少同步 删代码 3-5% 所有训练

参考资源

  • 昇腾训练最佳实践:https://www.hiascend.com/document/detail/zh/CANN/
  • torch_npu API 文档:https://gitee.com/ascend/pytorch
  • ascend-transformer-boost:https://atomgit.com/cann/ascend-transformer-boost
  • HCCL 通信优化指南:https://www.hiascend.com/document/detail/zh/CANN/

总结

10 个技巧里,收益最大的是混合精度和 DDP——几乎适用所有场景,改动量小,效果显著。FlashAttention 是 Transformer 模型的标配。梯度压缩和优化器卸载解决的是大模型训练的显存瓶颈。减少同步和轻量检查点是容易被忽略的细节,改了就有收益。不需要一次全用上,按收益排序,从第一个开始,逐个加上去。

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐