请添加图片描述

文章目录

前言

在深度学习模型的部署流程中昇腾 CANN(Compute Architecture for Neural Networks)承担着将用户编写的计算图转换为可在昇腾 NPU 上高效执行的指令序列的核心职责。其中的 GE(Graph Engine)模块是图编译链路的关键入口,负责对计算图进行多轮语义-preserving 优化。常量折叠(Constant Folding)与死代码消除(Dead Code Elimination,简称 DCE)是 GE 图优化阶段中最基础也最有效的两类变换——它们在编译期完成原本需要逐迭代执行的计算,从而显著减少运行时开销。本文将围绕这两个优化技术的原理、实现机制、协同配合以及实战调优展开系统性的剖析,帮助读者建立从概念到代码的完整认知。

一、编译优化的基本思想

1.1 编译期优化 vs 运行时优化

深度学习框架在模型执行时,默认以即时执行(Eager Execution)模式运行:每个算子按照其出现的顺序逐个调用,数据以张量的形式在算子之间流动。这种模式灵活性高,但每一次运行时都要重复完成类型检查、形状推导、调度决策等固定开销。如果一个算子的输入在模型构建阶段就已经完全确定(例如权重常数、固定的 reshape 参数),那么在每次前向传播中重复执行这些计算就是一种纯粹的浪费。

编译期优化的核心思想是把能提前做的事提前做。编译器在静态分析阶段收集图中所有可用信息,用这些信息在模型实际运行前完成计算、化简图结构,从而让运行时的每一步都专注于真正需要依赖运行时数据的工作。编译期优化可以做到的事情包括但不限于:将常量表达式求值后直接替换为结果、删除永远不会执行的分支、合并冗余的数据搬运操作、以及识别可以融合的小算子组合。

1.2 为什么图级优化比算子级优化更重要

单算子实现(例如 Conv2d 的 CUDA kernel 或 NPU Tensor Processor 的微指令)关注的是如何在给定输入形状和数据类型下最大化硬件利用率。而图级优化关注的是整个计算图的结构是否合理:一个算子的输出是否真的需要返回到内存,还是可以直接作为下一个算子的输入?一个 Shape 算子是否真的需要在运行时反复查询,还是可以直接在编译期推断?这些决策直接影响内存占用、kernel 启动次数以及数据搬运量。对于昇腾 NPU 这类专用加速器来说,减少 host-device 数据传输次数和减少图节点数量往往是提升端到端性能的关键。

GE 模块的优化 pipeline 设计正是围绕这一思想展开的:它首先对图进行结构层面的简化(常量折叠、死代码消除),再进行算子融合(融合多个小算子为一个大算子),最后做内存布局优化(数据格式转换的合并)。本文聚焦的前两个阶段构成了整条优化链的基础。

二、常量折叠的原理

2.1 什么是常量折叠

常量折叠(Constant Folding)是一种在编译期将只包含常量值的表达式直接求值,并用结果常量节点替换该表达式的优化技术。替换后,原来的算子节点(及其输入依赖)如果不再被其他节点使用,就可以进一步被消除。

举一个最直观的例子。假设模型中存在如下两个连续算子:

Const(值为3.0) → Mul(x, 3.0)

如果 x 在某次具体推理中是常量 2.0,则整个表达式 2.0 * 3.0 = 6.0 可以在编译期确定结果。GE 在检测到 Mul 算子的两个输入都为常量后,会在图上生成一个新的常量节点 6.0,删除原来的 Mul 节点,并用新节点替换所有对原 Mul 输出的引用。

2.2 表达式化简与代数化简

常量折叠不仅仅是数值计算,它还包含一系列代数化简规则。例如:

  • x + 0 → x
  • x * 1 → x
  • x - x → 0
  • Reshape(Const(tensor), target_shape) → 直接将 Reshape 的输出记录为常量形状

这些化简在编译期完成,可以显著削减图中的冗余节点。GE 内部的常量折叠引擎会维护一个规则表,对常见的二元运算、一元运算和变换算子逐条匹配。匹配成功后,生成等价的新常量节点并执行替换。

2.3 折叠决策边界

并非所有看起来"可以折叠"的算子都应该被折叠。GE 在折叠决策时需要综合考虑以下因素:

精度保留。如果原始算子以 float32 计算,而折叠后生成的常量以 float16 存储,则可能在后续与其他 float32 算子混合使用时引发隐式的精度转换。GE 在实现中会对这种精度变化进行评估,如果变化幅度超出用户指定的容忍阈值,则放弃折叠。

副作用边界。某些算子在执行时会产生外部可见的效应,例如模型参数的原地更新、随机数生成器的状态变更、打印节点的日志输出等。这类算子的输出即使输入全为常量,也不能简单地用编译期结果替换——因为替换会绕过算子的副作用逻辑。GE 维护了一张"副作用算子"白名单,对于在此名单上的算子强制跳过折叠。

运行时形状依赖。如果一个算子的输出形状依赖于输入张量的具体值(例如 Gather 算子中索引张量的内容),则即使索引张量是常量,GE 也需要确认形状推导在整个图上下文中是否闭合。如果存在其他算子依赖 Gather 的动态形状信息,折叠操作必须保守处理。

算子融合兼容性。GE 的常量折叠模块在决策时还会参考后续融合阶段的融合模式表。如果当前节点组合在未来某个融合阶段极可能被融合为一个新的复合算子,则优先保留节点结构以免干扰融合决策。

三、常量传播与跨算子常量传播

3.1 常量传播的基本概念

常量传播(Constant Propagation)是常量折叠的自然扩展。两者的区别在于:常量折叠处理的是"输入本身是常量"的算子,而常量传播处理的是"通过一个常量输入算子,可以推导出另一个输入也是常量"的情形。

例如,在如下图中:

ConstA(shape=[1, 128]) → Shape → Gather(ConstB, indices)

如果 ConstB 是常量张量,则 Shape 算子虽然本身不是常量节点,但其输出形状可以根据 ConstA 的 shape 属性在编译期完全确定。GE 的常量传播引擎会追踪这种间接依赖:当 ConstA 为常量时,推导出 Shape 的输出也是编译期可知的常量;进而 Gather 的索引输入为常量时,整个 Gather 结果也变成可折叠的。

3.2 数据流分析框架

常量传播依赖于数据流分析(Data Flow Analysis)框架。GE 在图优化的早期阶段会构建数据流图(Data Flow Graph),将图中的算子节点组织为有向无环图(DAG),并对每个节点的输入输出类型进行标注。

常量传播的数据流分析以如下方式工作:

  1. 初始化阶段:遍历图中所有 source 节点(如 ConstParameter),将它们的输出标记为"可能是常量"(May-Const)。
  2. 迭代阶段:对图进行拓扑排序,对于每个算子节点,检查其所有输入是否已确定为常量。如果所有输入都是常量,则将该节点的输出标记为常量,并执行折叠。
  3. 收敛阶段:迭代直到没有新的常量节点产生为止。

由于深度学习计算图天然是有向无环结构(单输入单输出算子之间不存在环),数据流分析的迭代过程必然在有限步数内收敛。GE 在实现中使用了高效的拓扑排序算法,并在每个迭代轮次中批量处理可折叠的节点组合以减少遍历开销。

3.3 跨算子常量传播的挑战

跨算子常量传播面临一个微妙的问题:当一个算子的输出被多个后续算子引用时,如果其中一个后续算子触发了常量传播从而删除了对上游节点的引用,可能会导致其他后续算子失去输入。这在 GE 中通过引用计数(Reference Counting)机制来解决——每个节点维护一个引用计数器,只有当引用计数降为零时,节点才被标记为可删除。

另一个挑战是控制流。昇腾 CANN 的计算图虽然在大多数情况下是纯函数式的,但在条件分支(If 算子)和循环(While 算子)中存在多条执行路径。GE 对控制流节点的处理采用区间分析(Region Analysis)方法:首先分析每个分支的常量属性,然后在满足特定条件时(例如两个分支的常量折叠结果相同)对控制流节点本身进行简化。

四、死代码消除(DCE)原理

4.1 什么是死代码

死代码(Dead Code)是指在图的语义下永远不会参与最终输出的节点及其传递依赖。这些代码在运行时不会产生任何可见的效果,却占用了图优化阶段的处理资源和运行时的调度资源。常见类型的死代码包括:

  • 没有任何后续算子引用的节点
  • 输出被完全丢弃的调试算子(如 PrintAssert
  • 条件分支中永远不会被选中的分支内的所有节点
  • 由于常量折叠或代数化简产生的冗余中间节点

4.2 无副作用节点的识别

死代码消除的一个关键前置问题是判断一个节点是否具有副作用。副作用的定义是:节点的执行除了产生输出张量外,还会对外部状态产生可观测的影响。

在昇腾 CANN 的图语义中,以下类型的算子被认为具有副作用:

  • 权重更新类算子(如 SGD 更新、BatchNorm 的移动均值/方差更新)
  • 随机数生成算子(RandomNormalRandomUniform
  • 输入输出类算子(DataPrintDumpTensor
  • 状态读写算子(AssignApplyMomentum 的某些变体)

对于这些具有副作用的节点,即使它们的输出没有被任何下游算子使用,也不能直接删除。GE 在 DCE 阶段维护了一张副作用算子分类表,并在遍历时对这类节点进行特殊标记:保留节点本体,但可以移除那些确实没有被引用的输出边。

4.3 引用计数分析

GE 的 DCE 实现采用了引用计数(Reference Counting)与反向拓扑遍历相结合的两轮算法。

第一轮(反向遍历):从图的输出节点反向遍历,追踪所有可达节点。遍历过程中对每个经过的节点递增其引用计数(或者在更精细的实现中使用位图标记可达性)。遍历完成后,所有未被标记为可达的节点都是死代码。

第二轮(删除与级联):删除所有死代码节点。在删除过程中,如果一个节点的删除导致其上游节点的引用计数降为零,则递归地将该上游节点也标记为死代码并删除。这一步被称为"级联死代码消除"。

下面的伪代码展示了 GE 中 DCE 的核心逻辑:

// GE DCE 核心算法伪代码
void GraphDCE::EliminateDeadNodes(GeGraph& graph) {
    // 第一轮:反向可达性标记
    std::unordered_set<GeNode*> reachable;
    std::queue<GeNode*> reverse_bfs;
    for (auto& output_node : graph.GetOutputNodes()) {
        reverse_bfs.push(output_node);
    }
    while (!reverse_bfs.empty()) {
        auto node = reverse_bfs.front();
        reverse_bfs.pop();
        if (reachable.count(node)) continue;
        reachable.insert(node);
        for (auto& input_node : node->GetInputs()) {
            reverse_bfs.push(input_node);
        }
    }

    // 第二轮:删除所有不可达节点
    for (auto& node : graph.GetAllNodes()) {
        if (!reachable.count(node)) {
            // 级联删除:检查上游是否也变成死代码
            RemoveNodeAndCascade(graph, node);
        }
    }
}

void GraphDCE::RemoveNodeAndCascade(GeGraph& graph, GeNode* node) {
    // 递减上游节点的引用计数
    for (auto& input_node : node->GetInputs()) {
        input_node->DecrementRefCount();
        if (input_node->GetRefCount() == 0 && !HasSideEffect(input_node)) {
            RemoveNodeAndCascade(graph, input_node);
        }
    }
    graph.RemoveNode(node);
}

4.4 条件分支的死代码消除

当条件分支算子(IfCond)的条件输入是编译期已知的常量时,整个分支的语义可以直接化简:如果条件为 True,则整个 If 节点替换为其 Then 分支;如果条件为 False,则替换为 Else 分支。GE 在 DCE 之前会先运行一个常量传播步骤来识别这类可化简的控制流,从而在 DCE 阶段进一步清理被舍弃分支中的所有死代码。

五、GE 图优化中的完整实现

5.1 优化 Pass 流水线

GE 的图优化以 Pass(优化轮次)为单位组织。每个 Pass 完成一类特定的图变换,多个 Pass 按依赖顺序串联成完整的优化流水线。与常量折叠和死代码消除相关的 Pass 排列顺序如下:

1. 输入规范化(Input Canonicalization)
2. 常量折叠(Constant Folding Pass)
3. 代数化简(Algebraic Simplification Pass)
4. 常量传播(Constant Propagation Pass)
5. 死代码消除(Dead Code Elimination Pass)
6. 节点替换(Node Substitution Pass)
7. 算子融合(Operator Fusion Pass)
8. 内存优化(Memory Layout Optimization Pass)

常量折叠在代数化简之后执行,是因为代数化简可能产生新的可折叠机会(例如 x - x → 0 之后,0 * y → 0)。死代码消除放在所有其他结构性优化之后执行,是因为其他 Pass 可能会产生新的死代码节点——只有在所有结构性变换完成后,才能得到最精简的图。

5.2 节点替换与等价变换

节点替换(Node Substitution)是 GE 中将一个节点替换为另一个语义等价节点的核心机制。在常量折叠和常量传播的过程中,GE 并不是直接修改原始节点的属性,而是通过节点替换 Pass 生成新的节点并更新图结构。这种设计有两个好处:它保持了 Pass 之间的隔离性(每个 Pass 只声明"期望做什么",不直接修改图),同时也为调试和可复现性提供了良好的基础——每个 Pass 的输出都可以单独 dump 出来进行人工审查。

等价变换的典型场景包括:将 Transpose + Reshape 复合序列替换为单一的 Reshape(当维度重排与形状变换可以合并时),以及将两个相邻的 Cast 算子合并为最终的单一类型转换。

5.3 串联流程详解

以下展示了一个完整的常量折叠 + DCE 串联流程在 GE 中的工作过程:

初始图:
  ConstA = 2.0
  ConstB = 3.0
  ConstC = Parameter("input")
  Mul1 = Mul(ConstC, ConstA)   // 输入之一是常量
  Mul2 = Mul(Mul1, ConstB)     // 两个输入都依赖常量,可折叠
  Print = Print(Mul2)           // 调试节点,无下游引用

Pass 1 - 常量折叠:
  分析 Mul2:输入 Mul1 非常量(依赖 Parameter),ConstB 是常量
  但 GE 的折叠策略要求"所有直接输入都是常量"
  Mul2 无法直接折叠

Pass 2 - 常量传播:
  发现 ConstC 虽然是 Parameter 但在当前推理中是常量输入值 4.0
  传播后:Mul1 的输入全部确定为常量 → Mul1 结果为 8.0
  继续传播:Mul2 的输入全部确定为常量 → Mul2 结果为 24.0
  生成新常量节点 ConstFoldResult = 24.0

Pass 3 - 节点替换:
  将所有指向 Mul2 的边重定向到 ConstFoldResult
  删除 Mul1 和 Mul2 节点

Pass 4 - 死代码消除:
  Print 节点没有任何下游引用,且无副作用 → 删除
  ConstA、ConstB 若引用计数降为零 → 级联删除

最终图:
  ConstC = Parameter("input")
  ConstFoldResult = 24.0 (常量)
  输出 = ConstFoldResult

这个流程展示了各 Pass 之间的数据依赖关系:常量传播为折叠创造条件,折叠为 DCE 准备条件,而 DCE 又进一步清理传播过程中产生的临时节点。

六、与算子融合的协同

6.1 常量折叠为融合创造条件

算子融合(Operator Fusion)是将多个连续的小算子合并为单一算子执行的技术。融合后的算子在昇腾 NPU 上可以实现更高的数据局部性(无需将中间结果写回全局内存),并减少 kernel 启动开销。但融合决策需要满足一个前提条件:被融合的算子之间不存在"阻止融合的边界"。

常量折叠的作用在于消除这些边界。例如:

ConstA = 128
ConstB = 3
ExpandDims = ExpandDims(input, ConstA)
Reshape    = Reshape(ExpandDims, [batch, 128, 3, H, W])

如果 ExpandDims 的 axis 和 Reshape 的 target shape 都可以在编译期确定,则整个 ExpandDims + Reshape 序列可以被折叠为一个新的常量 Reshape 操作,从而消除两个算子之间的数据依赖边界。此时,一个原本需要分别调度两个 kernel 的序列,在融合阶段可以直接被 FuseIntoReshape 规则匹配并合并。

6.2 融合后的新常量折叠机会

融合本身也可能创造新的常量折叠机会。当多个小算子融合为一个大算子后,融合算子的输入如果变得更"纯粹"(减少了中间临时张量),则常量折叠规则可能在新的大算子上重新触发。GE 在融合 Pass 之后会重新执行一轮轻量级的常量折叠检查,以确保融合产生的复合算子不会被遗留的可折叠表达式拖累。

6.3 融合模式表与折叠决策的交互

GE 的融合引擎维护了一张融合模式表(Fusion Pattern Table),其中每条规则定义了输入算子序列 → 输出融合算子的映射。融合决策算法在遍历图时查找与模式表匹配的子图。在匹配过程中,如果子图中存在常量输入节点,GE 会优先尝试将这些常量在匹配前进行折叠,以避免常量节点干扰模式匹配的路径对齐。

七、两个关键陷阱与规避方案

7.1 陷阱一:常量折叠后精度变化

问题描述。当一个算子的所有输入都是常量时,GE 在编译期使用内部的数值计算引擎对该算子求值。这个求值过程可能使用不同于运行时精度要求的计算路径。例如,某个 MatMul 算子在昇腾 NPU 上以混合精度(float16 计算、float32 累加)执行,但如果在常量折叠阶段以纯 float16 进行求值,最终存储的常量精度可能会低于运行时的结果。

这种精度差异在大模型训练的后期阶段影响尤为明显:当 loss 已经收敛到很小的数值时,精度损失可能导致优化器状态出现不可忽视的偏差。

规避方案。GE 提供了精度模式(Precision Mode)配置选项,用户可以在图编译配置中指定常量折叠的数值计算使用与运行时一致的精度级别:

import ascend.ge as ge

graph_options = ge.GraphOptions()
# 设置常量折叠的数值精度模式为与算子运行时一致
graph_options.folding_precision_mode = "fp32"
ge.set_global_options(graph_options)

# 加载并编译模型
framework = ge.Framework("onnx")
graph = framework.load_model("model.onnx")
compiled = ge.compile_graph(graph)

同时,对于涉及归一化统计量(如 BatchNorm 的 running_mean/running_var)的常量折叠,建议用户显式地将这些算子标记为"不可折叠"(通过算子属性 can_fold=False),确保它们在运行时以正确的精度和状态执行。

7.2 陷阱二:副作用节点误删

问题描述。死代码消除阶段的核心假设是"如果一个节点没有被任何后续节点引用,则它不会影响最终输出"。然而,这个假设在存在副作用的算子上不成立。典型的误删场景包括:

  • 误将 UpdateWeight 类算子标记为死代码(因为在当前图中没有后续算子引用其输出,但该算子会原地修改 Parameter)
  • Print 节点误删(某些调试流水线要求保留 Print 以满足日志合规要求)
  • Dropout 的 mask 张量引用节点删掉(虽然 mask 不参与反向传播,但在训练模式下 Dropout 算子的副作用是必须保留的)

规避方案。GE 在 DCE Pass 中为每类副作用算子定义了明确的保留策略:

# 配置 DCE 阶段的保留策略
dce_options = ge.DCEOptions()
# 保留所有标记为 keep_side_effect 的算子
dce_options.preserve_side_effect_nodes = True
# 保留指定名称模式的节点(支持通配符)
dce_options.preserve_nodes_by_pattern = ["UpdateWeight*", "Print*", "Dropout*"]

# 应用 DCE 配置并执行图优化
optimized_graph = ge.run_dce(graph, dce_options)

此外,GE 在执行 DCE 之前会自动进行副作用分析(Side Effect Analysis):对于图中每个节点,如果它属于副作用算子白名单(该白名单可以通过配置文件扩展),则即使它的输出未被引用,节点也会被保留。这一机制确保了误删的兜底防护。

八、实战代码

8.1 代码块一:查看当前图的优化状态

import ascend.ge as ge

# 加载模型并查看原始图的节点数量
framework = ge.Framework("onnx")
graph = framework.load_model("resnet50.onnx")
print(f"原始图节点数: {graph.get_node_count()}")
print(f"原始图算子类型分布:")
for op_type, count in graph.get_op_type_distribution().items():
    print(f"  {op_type}: {count}")

8.2 代码块二:启用常量折叠优化

# 获取全局优化选项
opt_options = ge.OptimizerOptions()

# 显式启用常量折叠和死代码消除
opt_options.enable_constant_folding = True
opt_options.enable_algebraic_simplification = True
opt_options.enable_dead_code_elimination = True

# 设置常量折叠的最大迭代轮次(默认 5)
opt_options.max_folding_iterations = 8

# 应用优化选项
ge.set_global_options(opt_options)

# 重新编译图
compiled = ge.compile_graph(graph)
optimized_graph = compiled.get_optimized_graph()
print(f"优化后节点数: {optimized_graph.get_node_count()}")

8.3 代码块三:调试 Dump——查看常量折叠前后对比

import ascend.ge as ge

# 设置 Dump 选项:将常量折叠前后的图分别输出到指定目录
dump_options = ge.DumpOptions()
dump_options.dump_path = "./ge_debug_dump"
dump_options.dump_mode = "after_each_pass"
dump_options.dump_formats = ["pb", "txt"]  # 同时输出二进制和文本格式

# 启用 Dump
ge.enable_dump(dump_options)

# 执行编译(会在每个 Pass 后自动生成 dump 文件)
compiled = ge.compile_graph(graph)

# Dump 输出结构示例:
# ./ge_debug_dump/
#   resnet50_original.pb
#   resnet50_after_constant_folding.pb
#   resnet50_after_constant_folding.txt
#   resnet50_after_dce.pb
#   resnet50_after_dce.txt

8.4 代码块四:读取 Dump 文件并分析常量传播效果

import ascend.ge as ge

# 读取常量折叠后的 Dump 文件
folded_graph = ge.Graph.load_from_pb("./ge_debug_dump/resnet50_after_constant_folding.pb")

# 遍历图中所有常量节点
const_nodes = [n for n in folded_graph.get_all_nodes() if n.type == "Const"]
print(f"常量节点数量: {len(const_nodes)}")
for node in const_nodes:
    tensor = node.get_output_tensor(0)
    print(f"  节点名: {node.name}")
    print(f"  Shape: {tensor.shape}, Dtype: {tensor.dtype}")
    # 打印前 5 个元素(截断显示)
    if tensor.size < 6:
        print(f"  值: {tensor.to_list()}")
    else:
        print(f"  值(前5个): {tensor.to_list()[:5]} ...")

8.5 代码块五:等价变换示例——合并连续 Reshape

import ascend.ge as ge

# 查找图中的连续 Reshape 模式
fusion_pattern = ge.FusionPattern("reshape_chain")
fusion_pattern.add_rule(
    pattern_nodes=["Reshape", "Reshape"],
    fused_op="Reshape",
    condition=lambda n1, n2: n2.get_attr("shape").is_computable_from(n1)
)

# 注册自定义融合模式
ge.register_fusion_pattern(fusion_pattern)

# 应用融合并查看效果
fused_graph = ge.fuse_operators(graph)
print("融合前 Reshape 数量:", len([n for n in graph.get_all_nodes() if n.type == "Reshape"]))
print("融合后 Reshape 数量:", len([n for n in fused_graph.get_all_nodes() if n.type == "Reshape"]))

8.6 代码块六:控制常量折叠精度

import ascend.ge as ge

# 精细控制各类型的常量折叠行为
folding_config = {
    # MatMul 和 Conv 等核心算子使用 FP32 折叠
    "core_ops": {"precision": "fp32", "enabled": True},
    # Shape 相关算子不折叠(运行时形状依赖)
    "shape_ops": {"precision": "any", "enabled": False},
    # Dropout mask 使用原样保留
    "random_ops": {"precision": "any", "enabled": False},
}

# 通过 API 注入折叠配置
ge.configure_constant_folding(folding_config)

# 验证配置是否生效
cfg = ge.get_constant_folding_config()
print("MatMul 折叠精度:", cfg["core_ops"]["precision"])
print("Shape 算子折叠:", cfg["shape_ops"]["enabled"])

8.7 代码块七:查看 DCE 前后的引用计数变化

import ascend.ge as ge

# 在 DCE 前后分别输出引用计数分布
def print_refcount_stats(g, label):
    refcounts = {}
    for node in g.get_all_nodes():
        rc = node.get_ref_count()
        refcounts[rc] = refcounts.get(rc, 0) + 1
    print(f"\n[{label}] 引用计数分布:")
    for rc in sorted(refcounts.keys()):
        print(f"  refcount={rc}: {refcounts[rc]} 个节点")

original_graph = graph
print_refcount_stats(original_graph, "DCE 前")

# 执行 DCE
dced_graph = ge.run_dce(original_graph)
print_refcount_stats(dced_graph, "DCE 后")

8.8 代码块八:保留特定名称模式的节点(防止误删)

import ascend.ge as ge

dce_config = ge.DCEOptions()
# 通过正则表达式指定保留的节点名称模式
dce_config.preserve_patterns = [
    r"^.*update_weight.*$",   # 保留权重更新类节点
    r"^.*print_output.*$",    # 保留打印节点
    r"^.*dropout_mask.*$",   # 保留 Dropout mask
]

# 单独执行 DCE Pass
dced_graph = ge.run_dce(graph, dce_config)

# 验证保留效果
preserved = [n.name for n in dced_graph.get_all_nodes()
             if "update_weight" in n.name or "print_output" in n.name]
print(f"保留的节点: {preserved}")

8.9 代码块九:多轮迭代——观察折叠 + DCE 串联效果

import ascend.ge as ge

def optimize_iteratively(g, max_rounds=10):
    """执行多轮常量折叠 + DCE,直到图不再变化"""
    for round_idx in range(1, max_rounds + 1):
        prev_node_count = g.get_node_count()
        prev_const_count = len([n for n in g.get_all_nodes() if n.type == "Const"])

        g = ge.run_constant_folding(g)
        g = ge.run_dce(g)

        curr_node_count = g.get_node_count()
        curr_const_count = len([n for n in g.get_all_nodes() if n.type == "Const"])

        delta_nodes = curr_node_count - prev_node_count
        delta_consts = curr_const_count - prev_const_count
        print(f"轮次 {round_idx}: 节点变化 {delta_nodes:+d} (共 {curr_node_count}), "
              f"常量节点变化 {delta_consts:+d} (共 {curr_const_count})")

        if delta_nodes == 0:
            print("图已收敛,停止优化")
            break
    return g

final_graph = optimize_iteratively(graph)

8.10 代码块十:自定义常量折叠规则扩展

import ascend.ge as ge

# 定义自定义折叠规则:识别 Scale + Bias + ReLU 融合前的常量折叠
class FoldScaleBiasReLU(ge.FoldingRule):
    def match(self, node_chain):
        """
        匹配模式: Mul(Add(input, bias), scale) -> ReLU
        其中 scale 和 bias 都是常量
        """
        if len(node_chain) < 3:
            return None
        mul_node, add_node, relu_node = node_chain[-3:]
        if mul_node.type != "Mul" or add_node.type != "Add" or relu_node.type != "ReLU":
            return None
        if not (mul_node.inputs[1].is_constant() and add_node.inputs[1].is_constant()):
            return None
        return {"scale": mul_node.inputs[1].value, "bias": add_node.inputs[1].value}

    def apply(self, match_info):
        scale = match_info["scale"]
        bias = match_info["bias"]
        # 生成融合后的常量偏移(scale * input + bias 在编译期求值)
        fused_bias = scale * bias  # 折叠计算
        return ge.NodeBuilder("Add").input("input").input_const(fused_bias).build()

# 注册自定义规则
ge.register_folding_rule(FoldScaleBiasReLU())
print("自定义折叠规则已注册")

8.11 代码块十一:图编译完整流程(含优化配置)

import ascend.ge as ge

# ============ 第一步:加载模型 ============
framework = ge.Framework("onnx")
graph = framework.load_model("bert_base.onnx")
print(f"模型加载完成,原始节点数: {graph.get_node_count()}")

# ============ 第二步:配置优化 Pass ============
opt = ge.OptimizerOptions()
opt.enable_constant_folding = True
opt.enable_algebraic_simplification = True
opt.enable_constant_propagation = True
opt.enable_dead_code_elimination = True
opt.enable_node_substitution = True
opt.max_folding_iterations = 10
ge.set_global_options(opt)

# ============ 第三步:执行图优化 ============
optimized = ge.optimize_graph(graph)
print(f"图优化完成,优化后节点数: {optimized.get_node_count()}")

# ============ 第四步:编译为目标格式 ============
compile_opts = ge.CompileOptions()
compile_opts.target = "Ascend"          # 目标设备:昇腾 NPU
compile_opts.precision_mode = "mixed"   # 混合精度模式
compile_opts.compile_mode = "train"     # 训练模式

compiled = ge.compile_graph(optimized, compile_opts)
print(f"图编译完成,生成模型大小: {compiled.get_model_size() / 1024:.1f} KB")

# ============ 第五步:保存编译结果 ============
compiled.save("bert_base_optimized.om")
print("优化模型已保存为 bert_base_optimized.om")

8.12 代码块十二:通过日志级别查看优化详情

import ascend.ge as ge

# 设置日志级别为 DEBUG 以查看常量折叠和 DCE 的详细过程
ge.set_log_level("DEBUG")

# 设置日志输出到文件(方便事后分析)
ge.configure_logger(
    level="DEBUG",
    log_file="./ge_optimization.log",
    also_print_to_console=True
)

# 执行编译(DEBUG 日志会输出每个 Pass 的详细处理信息)
compiled = ge.compile_graph(graph)

DEBUG 日志中与常量折叠相关的内容示例:

[DEBUG] Pass: constant_folding_iteration_1
[DEBUG]   扫描到常量输入 Mul 节点: node_name=Mul_12, 可折叠
[DEBUG]   执行折叠: Mul_12(Const(2.0), Const(3.0)) -> Const(6.0)
[DEBUG]   替换引用: 原有 4 个引用全部重定向
[DEBUG]   删除旧节点: Mul_12
[DEBUG] Pass: dead_code_elimination
[DEBUG]   反向可达性分析: 标记了 156 个可达节点
[DEBUG]   发现死代码: Print_3 (无副作用,无下游引用)
[DEBUG]   级联删除: Const_5 (refcount 0)
[INFO] DCE 完成,删除了 7 个节点

8.13 代码块十三:验证优化后的图语义等价性

import ascend.ge as ge
import numpy as np

# 生成随机输入数据
np.random.seed(42)
input_data = {
    "input_ids": np.random.randint(0, 21128, size=(1, 128), dtype=np.int32),
    "attention_mask": np.ones((1, 128), dtype=np.int32),
}

# 分别运行原始图和优化后图,对比输出
original_model = ge.CompiledModel("bert_base.onnx")
optimized_model = ge.CompiledModel("bert_base_optimized.om")

original_output = original_model.execute(input_data)
optimized_output = optimized_model.execute(input_data)

# 检查输出的最大绝对误差
max_diff = np.max(np.abs(original_output["last_hidden_state"]
                          - optimized_output["last_hidden_state"]))
print(f"优化前后输出最大绝对误差: {max_diff:.2e}")

if max_diff < 1e-4:
    print("✅ 语义等价性验证通过")
else:
    print("⚠️ 误差超出预期,建议检查优化配置")

8.14 代码块十四:融合 Pass 与常量折叠协同调试

import ascend.ge as ge

# 查看融合前的图状态(常量折叠后、融合前)
pre_fusion = ge.get_pass_output(graph, "after_constant_propagation")
print("融合前图中可融合的小算子数量:")

# 统计相邻的 Mul-Add 对(常见的融合候选)
mul_add_count = 0
for node in pre_fusion.get_all_nodes():
    if node.type == "Mul":
        consumers = pre_fusion.get_consumers(node)
        for consumer in consumers:
            if consumer.type == "Add":
                mul_add_count += 1

print(f"  Mul+Add 候选对: {mul_add_count}")

# 执行融合 Pass
post_fusion = ge.fuse_operators(pre_fusion)
print(f"融合后节点数: {post_fusion.get_node_count()}")

# 再次执行常量折叠(融合可能产生新的折叠机会)
final = ge.run_constant_folding(post_fusion)
print(f"融合后再折叠节点数: {final.get_node_count()}")

8.15 代码块十五:通过 Python API 查看优化 Pass 执行顺序

import ascend.ge as ge

# 查询 GE 中当前生效的所有优化 Pass 及其顺序
pass_list = ge.get_optimization_pass_list()
print("当前图优化 Pass 执行顺序:")
for idx, pass_info in enumerate(pass_list, 1):
    enabled = "✅" if pass_info.enabled else "❌"
    print(f"  {idx}. {enabled} {pass_info.name}")
    print(f"      描述: {pass_info.description}")
    print(f"      依赖 Pass: {pass_info.depends_on or '无'}")

# 交互式修改 Pass 顺序(例如将 DCE 提前)
ge.reorder_passes(["constant_folding", "dce", "algebraic_simplification"])
print("\nPass 顺序已调整")

九、结尾

本文系统性地剖析了昇腾 CANN 中 GE 模块在图优化阶段所执行的常量折叠与死代码消除技术。从编译优化思想的基本原理出发,逐层深入到常量传播的数据流分析框架、死代码消除的引用计数机制,再到 GE 中两者的完整串联实现与算子融合的协同配合,最后通过十五段实战代码展示了从配置、调试到验证的全流程操作。

常量折叠与死代码消除虽然是最基础的图优化技术,但它们的效果直接决定了后续算子融合的质量上限——一个被冗余节点和死代码充实的图,融合引擎将难以找到清晰可靠的融合路径。因此,掌握这两类优化技术的原理与调优方法,是深入理解昇腾 CANN 图编译体系的重要一步。

如果希望进一步了解图编译阶段中更多 Pass(如算子融合、内存布局优化、图划分等)的工作原理,推荐继续阅读 GE 模块中图编译阶段的相关章节。同时,昇腾 CANN 的开源代码仓库位于 https://atomgit.com/cann/ge,其中包含了 GE 模块的完整实现,欢迎读者深入探索与贡献。

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐