MoE 模型的分布式训练有个老问题:All-to-All 通信的时间跟计算时间差不多长,NPU 一半时间在算、一半时间在等数据。ops-transformer 仓库的 MC2(Merge-Compute-Communicate)通算融合,把通信和计算重叠到同一个时间窗口里,让昇腾NPU不用再等。

通信为什么是瓶颈

先看 MoE 一个 training step 的时间分解:

1. Token Embedding + Gate 计算          ≈ 5%
2. All-to-All Scatter(token 发到 expert) ≈ 20%
3. Expert FFN 计算                      ≈ 30%
4. All-to-All Gather(结果收回来)        ≈ 20%
5. Attention + 残差 + LayerNorm         ≈ 25%

步骤 2 和 4 加起来占 40%。这 40% 的时间里 Cube 单元是空的——它在等数据搬完。

标准实现里,通信和计算是串行的:发完数据才能算,算完才能收回来。但昇腾NPU的硬件实际上有两条独立通路——HCCL 通信走 DMA 引擎,计算走 Cube+Vector 引擎。两条路互不干扰。

MC2 就是让这两条路同时工作。

流水线设计

MC2 的核心是三层流水:

第一层:Micro-batch 切分。 把一个 batch 的 token 切成 micro-batch,每个 micro-batch 独立走通信→计算→通信的流程。

第二层:通信计算重叠。 第 N 个 micro-batch 在计算时,第 N+1 个 micro-batch 在通信。两条硬件通路并行。

时间线:
Micro 0: [通信→计算→通信]
Micro 1:        [通信→计算→通信]
Micro 2:               [通信→计算→通信]
                      ↑ 通信和计算重叠

第三层:归约合并。 所有 micro-batch 的结果按原始 token 顺序拼好,跟不做切分的结果数学等价。

关键参数是 micro-batch 的大小。切太小,通信次数增多,调度开销吃掉重叠收益;切太大,重叠窗口太小。ops-transformer 的默认值是按 expert 数量均分,8 expert 场景下 micro-batch = batch_size / 8。

CANN 8.5 的跨机适配

CANN 8.0 的 MC2 只支持单机内卡间通信(HCCL 内部走 PCIe/HCCS)。跨机 All-to-All 走 RoCE 网卡,MC2 管不到——通信和计算又变成串行的。

CANN 8.5 加了 RoCE 通算融合。MC2 同时管理 HCCL 和 RoCE 两条通信链路:

单机内:NPU 0 ↔ NPU 3(HCCL/HCCS)  ← MC2 管理
跨机:  NPU 3 ↔ NPU 4(RoCE/RDMA)  ← MC2 也管了

实现上,MC2 用 HCCL 的异步通信接口发起 All-to-All,然后在通信回调里触发计算。RoCE 的 RDMA 操作天然支持异步,不需要额外适配。HCCL 的同步操作需要改成异步模式,这是 CANN 8.5 在 HCCL 层做的修改。

性能数据

Atlas 800I A2,Mixtral 8x7B:

配置 通信占比 吞吐 (tokens/s/p)
无通算融合 47% 580
MC2 单机 18% 1,420
MC2 跨机(CANN 8.5) 23% 1,050

单机场景通信占比从 47% 砍到 18%,几乎完全隐藏。跨机场景因为 RoCE 延迟比 HCCS 高,还有 23% 没藏住,但比无融合的 47% 已经好很多。

用法

MC2 在 ATB 训练框架里默认启用。单算子调用:

import torch_npu

# MC2 封装了 All-to-All + Expert 计算 + All-to-All
# 输入:token embeddings + expert 权重 + 路由索引
output = torch_npu.npu.mc2_fused_expert(
    x,                    # [batch, seq, hidden]
    weights_up,           # [num_experts, hidden, ff_dim]
    weights_gate,         # [num_experts, hidden, ff_dim]
    weights_down,         # [num_experts, ff_dim, hidden]
    expert_ids,           # [batch, seq, topk]
    group=hccl_group      # HCCL 通信组
)

需要传入 HCCL 通信组(hccl_group),MC2 通过它发起异步 All-to-All。如果你用的是 ATB,通信组自动创建,不用手动传。

踩坑

Micro-batch 太小会退化。 batch_size=4 的场景下切 8 个 micro-batch,每个只有 0.5 个样本,通信次数翻了 8 倍,调度开销把重叠收益全吃光了。MC2 的收益在 batch_size ≥ 16 时才明显。

RoCE 网卡需要开启 RDMA。 跨机 MC2 依赖 RDMA 的零拷贝特性。如果 RoCE 网卡只配了 TCP 模式,MC2 会 fallback 到同步通信,性能回到无融合水平。检查方式:hccn_tool -i eth0 -roce -s,确认 RDMA 状态是 enabled。


如果你的 MoE 训练通信占比超过 30%,MC2 是最直接的优化手段。单机效果最好,跨机需要 CANN 8.5 + RDMA 环境。仓库在这里:

https://atomgit.com/cann/ops-transformer

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐