请添加图片描述

MoE 的 Expert Parallel 需要全互连通信——每个 token 发给它路由到的专家所在的卡,再收回来。这个 All-to-All 通信在 8 卡 MoE 上能占 30% 的推理时间。MC2(Merge-Communicate-Split)把通信和计算融合在一起,在等数据的时候不闲着。

All-to-All 通信的瓶颈

先说清楚 All-to-All 是什么。

在 Expert Parallel(EP)模式下,每个专家分布在不同的 GPU/NPU 上。LLaMA-MoE 的 8 个专家分布在 8 张卡上,每张卡有一个专家。

每个 token 要经过两步通信:

第一步:发送 token 到它被路由到的专家所在的卡。比如 token A 被路由到专家 3,专家 3 在卡 3 上,所以 token A 要从卡 0 发送到卡 3。

第二步:专家处理完 token,结果要发回原卡。

8 张卡全互连,每张卡同时发送和接收,总通信量是 8 × tensor_size × 平均跳数(跳数取决于物理拓扑)。

All-to-All 的瓶颈是延迟而不是带宽。小消息多(每个 token 只有几百个浮点数),每条消息都需要握手、同步、网络排队。这些开销累积起来,通信时间占 30%。

MC2 的核心思路

MC2 的核心观察是:等数据的时候不闲着。

标准流程是:先发完所有数据(等待),等所有数据到齐了,再开始计算(处理)。

MC2 的流程是:发第一个专家的数据,同时准备第二个专家的输入;等第一个专家的结果时,同时发第三个专家的数据。通信和计算完全流水线。

MC2 拆成三个阶段:

Merge 阶段:把要发送给同一个专家的数据打包在一起。不同 token 路由到不同的专家,Merge 阶段把它们按目标专家分组,减少通信次数。

Communicate 阶段:用 HCCL 的 All-to-All 发送数据。MC2 把 All-to-All 拆成多个小批次,和计算流水并行。

Split 阶段:收到的数据按来源卡分组,把结果分发给各自的原卡。

Ascend C 实现

// MC2 融合算子的 Ascend C 核心逻辑

// MC2 的关键设计:让通信和计算 overlap
// 每张卡维护一个状态机,管理 8 个专家的处理状态

// 专家处理状态枚举
enum ExpertState {
    WAITING_INPUT,    // 等待输入数据
    PROCESSING,      // 正在处理(计算中)
    SENDING_OUTPUT,  // 正在发送输出
    DONE             // 处理完成
};

// MC2 的主循环:轮询 8 个专家的状态,调度计算和通信
__aicore__ void MC2Kernel(GM_ADDR local_expert_output, ...) {
    ExpertState states[NUM_EXPERTS] = {WAITING_INPUT};
    bool all_done = false;
    
    while (!all_done) {
        // 遍历每个专家的状态
        for (int e = 0; e < NUM_EXPERTS; ++e) {
            switch (states[e]) {
                
            case WAITING_INPUT:
                // 检查是否有数据到达(通过 HBM 标志位判断)
                if (CheckDataArrived(e)) {
                    // 数据到达,开始处理
                    LoadExpertInput(e);  // 从 HBM 读到 UB
                    states[e] = PROCESSING;
                }
                break;
                
            case PROCESSING:
                // 专家计算(可能需要多个迭代)
                if (IsComputeDone(e)) {
                    // 计算完成,准备发送
                    states[e] = SENDING_OUTPUT;
                }
                break;
                
            case SENDING_OUTPUT:
                // 发送输出(异步的,数据写入 HBM 缓冲区,由 driver 处理网络传输)
                WriteOutputToHBM(e);
                NotifyPeerComplete(e);  // 发信号给对端
                states[e] = DONE;
                break;
                
            case DONE:
                // 这个专家的处理已完成
                break;
            }
        }
        
        // 检查是否所有专家都完成了
        all_done = CheckAllDone(states);
        
        // 给 HCCL 通信线程一个机会往前推进
        YieldToHCCL();  // 让 driver 处理已到达的网络数据
    }
}

MC2 的关键实现细节是 HCCL 的异步执行。All-to-All 的发送操作是异步的——调用 HCCL 的 All-to-All 接口只是发起传输,实际数据还在搬运。MC2 在发起 All-to-All 后,不阻塞等结果,而是切换到其他专家的计算。等 All-to-All 完成(通过 event 通知),再切换回来处理结果。

这个实现依赖昇腾 runtime 的异步通信 API 和 Event 同步机制。ops-transformer 仓的 moe_mc2_fusion.cpp 里有完整的实现。

MC2 vs 标准 All-to-All

标准 All-to-All 的时间线:T0 发起通信 → T1 等待所有数据到达 → T2 全部数据到达,开始计算 → T3 计算完成。总延迟 = (T1-T0) + (T3-T2)。

MC2 的时间线:T0 发起第 1 批通信,同时开始准备第 2 批数据 → T1 第 1 批数据到达,开始处理,同时发起第 2 批通信 → T2 第 1 批处理完成,开始发送,同时处理第 2 批数据 → … → T3 全部完成。总延迟 = T1-T0 + T3-T2,但 T3-T2 因为 overlap 大幅缩短。

MC2 的收益取决于通信和计算的比率。如果专家计算很快(专家参数量小),通信时间占比高,MC2 收益就大。如果专家计算很慢(专家参数量大,比如 Mixtral 的 experts=8×7B),通信时间占比低,MC2 收益就小。

实测数据(Mixtral-8×7B,EP=8):MC2 让 All-to-All 阶段的延迟从 25ms 降到 18ms,节省约 28%。

MC2 的收益随 EP 规模增长——EP=2 时,All-to-All 的总数据量和通信距离都小,MC2 收益只有 5-8%。EP=8 时,收益达到 25-30%。

https://atomgit.com/cann/ops-transformer

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐