请添加图片描述

前言

昇腾 CANN(Compute Architecture for Neural Networks)是华为面向昇腾 AI 处理器提供的一套异构计算架构,它处于算子库层与硬件驱动层之间,负责将上层模型图翻译为可在昇腾 NPU 上高效执行的计算指令。metadef 是昇腾 CANN 中负责算子元数据定义与图优化处理的核心里程碑模块,中文常译为"算子元数据定义框架"。它不仅仅是存放算子签名的容器,更是一套完整的推理引擎——在模型编译阶段,metadef 根据算子的输入 Shape 与属性(Attr)信息,自动计算出每个算子的输出 Shape,从而为后续的内存布局规划、图融合决策与代码生成奠定坚实的数据结构基础。没有可靠的 Shape 推导,所有依赖输出尺寸的优化都将无从下手。本文将系统拆解 metadef 中 Shape 推导引擎的架构设计、推导流程、核心算法与实战技巧,帮助读者真正理解这一底层基石的工作全貌。


一、Shape 推导要解决的核心问题

1.1 动态 Shape 的本质挑战

在深度学习框架中,Shape(张量形状)描述了一个多维数组在各个维度上的大小。静态 Shape(如 Tensor[3, 224, 224])在编译时完全已知,动态 Shape(如 Tensor[N, 224, 224] 中的 N 是运行时才确定的批次大小)则带来了编译器设计的根本性难题。Shape 推导引擎的核心职责,就是在输入 Shape 全部或部分已知的情况下,尽可能地推导出所有中间结果和最终输出的 Shape。

1.2 内存分配的连锁依赖

当 Shape 未知时,编译期无法确定张量的存储大小。如下所示,Conv 算子的输出 H、W 由输入尺寸、卷积核尺寸、步长、填充等属性共同决定:

output_H = floor((input_H + 2*pad - dilation_h*(kernel_h-1) - 1) / stride_h) + 1
output_W = floor((input_W + 2*pad - dilation_w*(kernel_w-1) - 1) / stride_w) + 1

如果编译器在 Shape 未知时就贸然按最大可能尺寸分配内存,会造成严重的显存浪费;如果按最小尺寸分配,又可能在运行时发生越界。一个成熟的 Shape 推导引擎需要在编译期精确计算每一个维度的上下界,为内存池管理器提供准确的空间规划依据。

1.3 算子编译的依赖链断裂

昇腾 CANN 的算子编译器 Ascend C 采用分离式编译策略:算子原语(OpDesc)定义接口,TIK(Tensor Iterator Kernel)风格的内核函数实现计算逻辑。Ascend C 编译器在生成内核代码时,需要在编译期确定每个缓冲区的大小和偏移量。Shape 信息缺失会直接导致内核函数中的静态内存声明无法完成,例如以下 Ascend C 代码中的 LocalTensor 分配:

// Ascend C 风格伪代码:编译期需要知道输出 Shape 才能分配本地张量
LocalTensor<float> output = tensorAlloc::allocate<DT_FLOAT>(outputSize);

outputSize 必须通过 Shape 推导获得,否则 Ascend C 编译器将报"维度未知"错误,编译链路在此断裂。

1.4 图优化的前提条件

现代编译器优化大量依赖 Shape 信息:算子融合(Fusion)需要判断相邻算子的输出/输入是否匹配;内存复用(Memory Reuse)需要精确计算每个张量的生命周期;自动并行(Auto Parallelism)需要根据 Shape 大小划分计算区间。如下图的融合优化场景中,Conv + BatchNorm 能否融合,取决于 Conv 的输出 Shape 是否与 BatchNorm 的期望输入 Shape 完全一致:

原始图: Conv -> BatchNorm -> Activation
融合图: Fuse(Conv+BatchNorm) -> Activation

如果 Shape 推导不准确,融合后的图在运行时可能产生维度不匹配错误,而这种错误往往在真正执行到该算子时才暴露,定位成本极高。


二、metadef Shape 推导引擎架构

2.1 整体架构概览

metadef 的 Shape 推导引擎采用插件化架构,核心组件可以划分为以下四层:

┌──────────────────────────────────────────────────────┐
│              用户侧注册层 (Operator Registry)         │
│  每个算子注册自己的 InferShapeFn 推导函数              │
├──────────────────────────────────────────────────────┤
│              推导调度层 (InferShape Dispatcher)        │
│  根据算子类型路由到对应的推导函数                       │
├──────────────────────────────────────────────────────┤
│              推导上下文 (InferShapeContext)            │
│  管理输入 Shape、属性、约束条件与推导结果缓存           │
├──────────────────────────────────────────────────────┤
│              规则执行层 (Shape Function Implementations)│
│  各个算子的具体 Shape 推导逻辑实现                      │
└──────────────────────────────────────────────────────┘

2.2 推导规则注册机制

metadef 使用统一的注册宏来声明算子的 Shape 推导函数。每一个算子(如 Conv2d、MatMul)在注册时需要同时提供两个关键函数:

  • InferShapeAndTypeFn:推导输出 Shape 与数据类型
  • InferFormatFn:推导输出数据排布格式(Format)

注册过程通过宏展开完成,注册信息存储在 metadef 的全局算子注册表中。以下是一个典型的算子注册结构:

// metadef 算子元数据注册伪代码
constexpr OpMetadata g_matmul_op = {
    .name = "MatMul",
    .input_names = {"x", "w", "b"},
    .output_names = {"y"},
    .infer_shape_fn = MatMulInferShape,    // Shape 推导函数入口
    .infer_format_fn = MatMulInferFormat,  // Format 推导函数入口
    .attr_defs = {
        AttrDef("transpose_x", Bool, false),
        AttrDef("transpose_y", Bool, false),
    },
};

这种注册机制的优势在于,算子开发者只需要实现业务逻辑(Shape 怎么算),而无需关心推导流程的调度与异常处理。metadef 框架在底层自动处理推导顺序、循环检测与缓存逻辑。

2.3 Shape 函数定义规范

每个算子的 Shape 推导函数必须遵循统一的接口签名。通常接受一个 InferShapeContext* 指针,返回 Status(成功或错误信息)。推导函数内部通过上下文对象读取输入 Shape 与属性,然后写入输出 Shape:

// Shape 推导函数的标准签名
Status MatMulInferShape(InferShapeContext* ctx) {
    // Step 1: 从上下文中获取输入 Shape
    Shape x_shape = ctx->GetInputShape("x");
    Shape w_shape = ctx->GetInputShape("w");

    // Step 2: 读取算子属性
    bool trans_x = ctx->GetAttr<bool>("transpose_x");
    bool trans_y = ctx->GetAttr<bool>("transpose_y");

    // Step 3: 根据 Shape 推导规则计算输出 Shape
    std::vector<int64_t> y_dims;
    // ... 推导逻辑(详见第四章)

    // Step 4: 将结果写回上下文
    ctx->SetOutputShape("y", Shape(y_dims));
    return Status::OK();
}

2.4 推导上下文的设计

InferShapeContext 是整个推导引擎的数据中枢,它封装了所有推导所需的状态信息:

// 推导上下文的内部结构(概念性定义)
struct InferShapeContext {
    // 输入相关
    std::vector<Shape> input_shapes_;     // 原始输入 Shape
    std::vector<Format> input_formats_;  // 输入数据格式

    // 属性相关
    OpAttrMap attrs_;                     // 算子属性名值对

    // 输出相关(推导结果写回目标)
    std::vector<Shape> output_shapes_;
    std::vector<Format> output_formats_;

    // 约束信息
    std::vector<ShapeConstraint> constraints_;  // 维度间约束关系

    // 缓存机制
    ShapeCache cache_;                  // 避免重复推导

    // 工具方法
    Shape GetInputShape(size_t idx);
    AttrValue GetAttr(const std::string& name);
    void SetOutputShape(size_t idx, const Shape& shape);
    void AddConstraint(const ShapeConstraint& c);
};

上下文还维护一个约束图(Constraint Graph),记录各维度之间的关系(如 dim[2] = dim[0] * dim[1] / 4),用于跨算子的联合推导。这是解决动态 Shape 中"间接依赖"问题的关键机制。

2.5 支持的 Shape 类型

metadef 的 Shape 推导引擎需要处理三种 Shape 类型:

Shape 类型 描述 示例 推导难度
完全静态 所有维度均为已知常数 [32, 128, 768] 最简单,编译期确定
半动态 部分维度依赖运行时变量 [N, 128, 768] 中等,需追踪变量符号
完全动态 所有维度均未知 [?, ?, ?] 最难,需要保守估计

对于半动态和完全动态 Shape,引擎会维护一套**符号化维度(Symbolic Dim)**系统,用符号名称(如 NMK)代替具体数值,在后续推导中追踪这些符号之间的数学关系。


三、推导过程的四个阶段

3.1 阶段一:输入 Shape 解析

第一阶段从解析算子的输入描述开始。metadef 从模型图的边(Edge)信息中提取每个输入张量的 Shape 与 Format。对于从常量(Const)或权重(Weight)节点传入的张量,其 Shape 在图构建期就已经固定,可以直接作为已知条件使用。

# Python 侧通过 ge_graph 接口设置输入 Shape 的示例
import acl

def setup_input_shape(graph, op_name, input_idx, shape):
    """为指定算子的某个输入设置已知的 Shape 信息"""
    input_desc = acl.create_tensor_desc()
    acl.set_tensor_shape(input_desc, shape)
    acl.set_tensor_format(input_desc, ACL_FORMAT_NCHW)
    print(f"[Phase 1] Input {input_idx} of op '{op_name}': {shape}")
    return input_desc

解析阶段还需要处理 Shape 中的动态维度标记。常见的标记方式包括:-1(由其他维度推导得出)、None(完全未知)和符号化维度(如 batch_size_0)。metadef 的解析器会将这些标记转换为内部的 SymbolicDim 对象,建立起推导的起点集合。

3.2 阶段二:约束传播

第二阶段是约束传播(Constraint Propagation)。这是动态 Shape 推导的核心环节。metadef 维护一个全局的约束图,每个节点代表一个维度(可能是具体数值,也可能是符号),每条边代表一个约束关系(如相等、大于、加法关系等)。

考虑以下模型片段:

Reshape: input[*, 768] -> output[N, H, W]
其中 N * H * W = 原始维度总量

Reshape 算子的约束传播器会接收 input[*, 768] 中的 * 是未知总量这一事实,并在约束图中建立如下关系:

constraint: dim_total == 768 * N * H * W
constraint: dim_total > 0

当后续算子(如 Conv)消费 Reshape 的输出时,约束信息会沿着数据流反向传播,形成一个约束网络。metadef 使用约束求解器(基于线性规划思想的简化版)来逐步削减每个符号维度的可能取值范围:

// 约束传播器的核心循环
void ConstraintPropagator::Propagate(const OpNode& op) {
    // 获取当前算子的输出约束
    auto output_constraints = op.infer_shape_fn()->GetConstraints();

    // 遍历所有下游消费者
    for (const OpNode& consumer : op.consumers()) {
        // 将输出约束反推到消费者的输入端
        for (const auto& constraint : output_constraints) {
            auto propagated = TranslateConstraint(constraint, consumer);
            consumer.MergeConstraint(propagated);  // 合并到消费者约束集
        }
    }

    // 如果约束集发生变化,触发新一轮传播
    if (constraints_changed_) {
        Propagate(consumer);  // 递归传播
    }
}

3.3 阶段三:输出 Shape 计算

在所有可用的输入 Shape 和约束条件收集完毕后,进入第三阶段:逐算子计算输出 Shape。每个算子的 Shape 函数根据其数学语义执行确定性计算。

ReduceSum 算子为例,其推导逻辑如下:

// ReduceSum 的 Shape 推导实现
Status ReduceSumInferShape(InferShapeContext* ctx) {
    auto input_shape = ctx->GetInputShape("x");
    auto axes_attr = ctx->GetAttr<std::vector<int64_t>>("axes");
    bool keep_dims = ctx->GetAttr<bool>("keep_dims");

    std::vector<int64_t> output_dims;

    if (keep_dims) {
        // keep_dims=true: 输出 Shape 与输入同维度,未缩减维度保持原值
        output_dims = input_shape.Dims();
        for (int64_t axis : axes_attr) {
            output_dims[axis] = 1;
        }
    } else {
        // keep_dims=false: 移除被求和的维度
        output_dims = input_shape.Dims();
        // 按降序移除轴以避免索引偏移
        std::vector<int64_t> sorted_axes = axes_attr;
        std::sort(sorted_axes.begin(), sorted_axes.end(), std::greater<int64_t>());
        for (int64_t axis : sorted_axes) {
            if (axis >= 0 && axis < static_cast<int64_t>(output_dims.size())) {
                output_dims.erase(output_dims.begin() + axis);
            }
        }
    }

    ctx->SetOutputShape("y", Shape(output_dims));
    return Status::OK();
}

3.4 阶段四:一致性校验

推导完成后,metadef 进入最后的一致性校验阶段(Consistency Verification)。这一阶段执行两类检查:

类型一:自洽性校验(Self-consistency Check)
验证算子的推导结果满足自身约束。例如,MatMul 的输出维度不能为负数;Reshape 的总元素数必须守恒:

// Reshape 的一致性校验
Status ReshapeConsistencyCheck(InferShapeContext* ctx) {
    auto input_shape = ctx->GetInputShape("x");
    auto output_shape = ctx->GetOutputShape("y");

    int64_t input_elements = input_shape.NumElements();
    int64_t output_elements = output_shape.NumElements();

    if (input_elements != output_elements) {
        return errors::InvalidArgument(
            "Reshape failed: input has ", input_elements,
            " elements, but output shape ", output_shape.DebugString(),
            " has ", output_elements, " elements. Elements must be preserved.");
    }
    return Status::OK();
}

类型二:跨算子一致性校验(Cross-op Consistency Check)
验证相邻算子之间的 Shape 匹配性。上一算子的输出 Shape 必须与下一算子的输入 Shape 完全兼容,包括维度数量和每个维度的具体值(已知维度必须相等,未知维度可接受自由匹配)。

如果某维度在两个相连的算子中被推导为不同的已知值(如前一个算子推导出该维度为 64,后一个算子要求为 32),则触发 Shape Conflict 错误,并附带完整的冲突路径信息。


四、常见算子的 Shape 推导规则详解

4.1 MatMul(矩阵乘法)

MatMul 的 Shape 推导规则是整个推导体系中最基础也最重要的参考。其数学语义决定了输出 Shape 的计算方式:

  • transpose_x = falsetranspose_y = false
    输出 Shape = [batch_dims..., M, K] 其中 K 来自 x 的最后一个维度,M 来自 y 的倒数第二个维度
  • 矩阵乘法要求前一个矩阵的列维等于后一个矩阵的行维,否则报错
// MatMul Shape 推导完整实现
Status MatMulInferShape(InferShapeContext* ctx) {
    auto x_shape = ctx->GetInputShape("x");
    auto w_shape = ctx->GetInputShape("w");

    bool trans_x = ctx->GetAttr<bool>("transpose_x");
    bool trans_y = ctx->GetAttr<bool>("transpose_y");

    const auto& x_dims = x_shape.Dims();
    const auto& w_dims = w_shape.Dims();

    // MatMul 至少是二维矩阵乘法
    if (x_dims.size() < 2 || w_dims.size() < 2) {
        return errors::InvalidArgument(
            "MatMul requires inputs to have at least 2 dimensions, "
            "got x: ", x_dims.size(), ", w: ", w_dims.size());
    }

    // 计算批次维度(broadcast 的维度)
    std::vector<int64_t> batch_dims;
    size_t x_batch_dims = x_dims.size() - 2;
    size_t w_batch_dims = w_dims.size() - 2;
    size_t max_batch = std::max(x_batch_dims, w_batch_dims);
    for (size_t i = 0; i < max_batch; ++i) {
        int64_t x_dim = (i < x_batch_dims) ? x_dims[i] : 1;
        int64_t w_dim = (i < w_batch_dims) ? w_dims[i] : 1;
        // 批次维度需要兼容
        if (x_dim != w_dim && x_dim != 1 && w_dim != 1) {
            return errors::InvalidArgument(
                "Batch dimension mismatch at index ", i,
                ": x_dim=", x_dim, ", w_dim=", w_dim);
        }
        batch_dims.push_back(std::max(x_dim, w_dim));
    }

    // 获取 K 维度(参与乘法的一侧)
    int64_t k_x = x_dims[x_dims.size() - (trans_x ? 1 : 2)];
    int64_t k_w = w_dims[w_dims.size() - (trans_y ? 2 : 1)];

    // K 维度必须相等或至少一个为动态
    if (!IsDynamic(k_x) && !IsDynamic(k_w) && k_x != k_w) {
        return errors::InvalidArgument(
            "MatMul dimension K mismatch: x has ", k_x,
            " (after transpose=", trans_x, "), w has ", k_w,
            " (after transpose=", trans_y, "). They must match.");
    }

    // 计算输出维度
    int64_t M = x_dims[x_dims.size() - (trans_x ? 2 : 1)];
    int64_t N = w_dims[w_dims.size() - (trans_y ? 1 : 2)];

    std::vector<int64_t> y_dims = batch_dims;
    y_dims.push_back(M);  // 保留维度(M)
    y_dims.push_back(N);  // 结果维度(N)

    ctx->SetOutputShape("y", Shape(y_dims));
    ctx->SetOutputType("y", ctx->GetInputType("x"));  // 继承数据类型
    return Status::OK();
}

4.2 Conv2d(卷积)

Conv2d 的 Shape 推导需要考虑输入 Shape、卷积核属性(kernel size、stride、padding、dilation)以及 groups 参数:

// Conv2d Shape 推导实现
Status Conv2dInferShape(InferShapeContext* ctx) {
    auto input_shape = ctx->GetInputShape("x");   // [N, C_in, H, W]
    auto filter_shape = ctx->GetInputShape("w");  // [C_out, C_in/g, kH, kW]

    int64_t stride_h = ctx->GetAttr<int64_t>("stride_h");
    int64_t stride_w = ctx->GetAttr<int64_t>("stride_w");
    int64_t pad_h = ctx->GetAttr<int64_t>("pad_h");
    int64_t pad_w = ctx->GetAttr<int64_t>("pad_w");
    int64_t dilation_h = ctx->GetAttr<int64_t>("dilation_h");
    int64_t dilation_w = ctx->GetAttr<int64_t>("dilation_w");
    int64_t groups = ctx->GetAttr<int64_t>("groups");

    const auto& dims = input_shape.Dims();  // [N, C_in, H, W]

    int64_t N  = dims[0];
    int64_t C_in = dims[1];
    int64_t H_in = dims[2];
    int64_t W_in = dims[3];

    // 计算输出空间维度
    // H_out = floor((H_in + 2*pad_h - dilation_h*(kH-1) - 1) / stride_h) + 1
    auto CalcSpatialDim = [](int64_t in_dim, int64_t pad, int64_t dilation,
                              int64_t kernel, int64_t stride) -> int64_t {
        if (IsDynamic(in_dim)) return MakeSymbolicDim();
        int64_t effective_kernel = dilation * (kernel - 1) + 1;
        return (in_dim + 2 * pad - effective_kernel) / stride + 1;
    };

    int64_t H_out = CalcSpatialDim(H_in, pad_h, dilation_h,
                                    filter_shape.Dims()[2], stride_h);
    int64_t W_out = CalcSpatialDim(W_in, pad_w, dilation_w,
                                    filter_shape.Dims()[3], stride_w);

    // 输出通道数
    int64_t C_out = filter_shape.Dims()[0];

    // 校验 groups 与通道数的匹配
    if (C_in % groups != 0 || C_out % groups != 0) {
        return errors::InvalidArgument("Conv2d groups mismatch: "
            "C_in=", C_in, ", groups=", groups, ", C_out=", C_out);
    }

    ctx->SetOutputShape("y", Shape({N, C_out, H_out, W_out}));
    return Status::OK();
}

4.3 Transpose(维度重排)

Transpose 算子根据 perm(维度排列)参数对输入 Shape 的维度进行重排。其推导规则极为简洁:输出 Shape 的第 i 个维度等于输入 Shape 在 perm 位置上的维度

// Transpose Shape 推导实现
Status TransposeInferShape(InferShapeContext* ctx) {
    auto input_shape = ctx->GetInputShape("x");
    auto perm = ctx->GetAttr<std::vector<int64_t>>("perm");

    const auto& input_dims = input_shape.Dims();
    int64_t rank = input_dims.size();

    if (static_cast<int64_t>(perm.size()) != rank) {
        return errors::InvalidArgument(
            "Transpose perm size (", perm.size(),
            ") must match input rank (", rank, ")");
    }

    // 验证 perm 是 [0, rank) 的一个排列
    std::vector<bool> seen(rank, false);
    for (int64_t p : perm) {
        if (p < 0 || p >= rank) {
            return errors::InvalidArgument("Invalid perm value: ", p,
                ". Must be in [0, ", rank, ")");
        }
        seen[p] = true;
    }

    std::vector<int64_t> output_dims(rank);
    for (int64_t i = 0; i < rank; ++i) {
        output_dims[i] = input_dims[perm[i]];
    }

    ctx->SetOutputShape("y", Shape(output_dims));
    return Status::OK();
}

4.4 Reshape(形状变换)

Reshape 的推导规则遵循元素守恒律:输出张量的总元素数必须等于输入张量的总元素数。当某个维度设为 -1 时,metadef 会自动用总元素数除以其他已知维度的乘积来推断该维度:

// Reshape Shape 推导实现
Status ReshapeInferShape(InferShapeContext* ctx) {
    auto input_shape = ctx->GetInputShape("x");
    auto new_shape_attr = ctx->GetAttr<std::vector<int64_t>>("shape");

    const auto& input_dims = input_shape.Dims();
    int64_t total_elements = input_shape.NumElements();

    std::vector<int64_t> output_dims = new_shape_attr;

    // 统计已知维度乘积和 -1 的位置
    int64_t known_product = 1;
    int64_t minus_one_idx = -1;
    for (int64_t i = 0; i < static_cast<int64_t>(output_dims.size()); ++i) {
        if (output_dims[i] == -1) {
            if (minus_one_idx != -1) {
                return errors::InvalidArgument(
                    "Reshape can have at most one -1 dimension. Found at least two.");
            }
            minus_one_idx = i;
        } else if (output_dims[i] >= 0) {
            known_product *= output_dims[i];
        } else {
            return errors::InvalidArgument(
                "Invalid dimension value: ", output_dims[i],
                ". Must be >= 0 or -1.");
        }
    }

    // 自动推导 -1 维度
    if (minus_one_idx != -1) {
        if (known_product == 0) {
            return errors::InvalidArgument(
                "Cannot infer -1 dimension when total elements is 0 or "
                "known product is 0.");
        }
        if (total_elements % known_product != 0) {
            return errors::InvalidArgument(
                "Reshape cannot reshape ", total_elements,
                " elements into shape with known product ", known_product,
                " (remaining dimension -1). They are not divisible.");
        }
        output_dims[minus_one_idx] = total_elements / known_product;
    }

    // 元素守恒校验
    int64_t output_elements = 1;
    for (int64_t d : output_dims) {
        output_elements *= d;
    }
    if (output_elements != total_elements) {
        return errors::InvalidArgument(
            "Reshape failed element count check: input has ", total_elements,
            " elements, output has ", output_elements, " elements.");
    }

    ctx->SetOutputShape("y", Shape(output_dims));
    return Status::OK();
}

五、推导失败的 diagnosing 机制

5.1 错误信息的层次化设计

昇腾 CANN 的 metadef 推导引擎采用了三级错误报告机制,确保错误信息既有足够的上下文,又不会因信息过载而难以阅读:

层级一:位置信息(Where)
指出错误发生在哪个算子节点、哪条边上,格式为 NodeName[input/output]:index

[ERROR] Shape inference failed at node: conv_block_3, output: 0

层级二:根因信息(Why)
解释推导失败的具体原因,尽可能给出数学层面的解释:

[ERROR] Reason: Dimension mismatch on axis 1.
         Expected: 256 (inferred from previous node's output)
         Actual: 128 (required by this node's attribute 'channels')

层级三:建议信息(How to fix)
给出用户侧可操作的修复建议:

[ERROR] Suggestion: Check the 'channels' attribute in conv_block_3 node,
         or verify the output shape of its upstream node.
         Set environment GE_SESSION_GRAPH_DUMP=1 for full trace.

5.2 用户侧调试技巧

当 Shape 推导失败时,以下几个调试手段可以帮助快速定位问题:

# 技巧一:开启推导过程的详细日志
import os
os.environ["GE_INFER_SHAPE_VERBOSE"] = "1"
# 重新运行模型加载,日志中会打印每个算子的输入 Shape、属性和推导结果

# 技巧二:使用 ge_graph 的 dump 接口导出中间结果
def dump_inference_trace(model_path, output_dir):
    """
    导出 Shape 推导的完整轨迹,用于事后分析
    """
    import subprocess
    cmd = [
        "atc", "--model", model_path,
        "--output", f"{output_dir}/ir",
        "--enable_inference_shape_dump", "true",
        "--inference_shape_dump_path", f"{output_dir}/shape_trace.json",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print("STDOUT:", result.stdout)
        print("STDERR:", result.stderr)
    return result.returncode

# 技巧三:单节点测试——独立验证某个算子的 Shape 推导
def test_single_op_infer(op_type, input_shapes, attrs):
    """
    单独测试某个算子的 Shape 推导逻辑,
    绕过整图推导的复杂性
    """
    from hccl.infer.api import OpInferShapeTest
    tester = OpInferShapeTest(op_type)
    for idx, shape in enumerate(input_shapes):
        tester.SetInputShape(idx, shape)
    for key, value in attrs.items():
        tester.SetAttr(key, value)
    ok, output_shapes = tester.Infer()
    print(f"Op: {op_type}")
    print(f"Input Shapes: {input_shapes}")
    print(f"Attrs: {attrs}")
    print(f"Output Shapes: {output_shapes}")
    print(f"Status: {'OK' if ok else 'FAILED'}")
    return ok, output_shapes

5.3 推导冲突的溯源

当多个算子的约束相互冲突时,metadef 会构建一个冲突传播链,从发生冲突的两个节点出发,逆向追溯到它们最近的共同祖先(Least Common Ancestor,LCA),找出最早产生不一致约束的节点:

// 冲突溯源算法的核心实现
std::vector<std::string> TraceShapeConflict(
    const OpNode& node_a,
    const OpNode& node_b,
    const ShapeValue& conflict_dim) {

    std::vector<std::string> trace_a, trace_b;

    // 反向追踪 node_a 的推导路径
    OpNode* cur = &node_a;
    while (cur != nullptr) {
        trace_a.push_back(cur->name() + ":dim[" +
            std::to_string(conflict_dim) + "] = " +
            cur->GetInferredDim(conflict_dim).DebugString());
        cur = cur->producer();  // 沿反向数据流移动
    }

    // 反向追踪 node_b 的推导路径
    cur = &node_b;
    while (cur != nullptr) {
        trace_b.push_back(cur->name() + ":dim[" +
            std::to_string(conflict_dim) + "] = " +
            cur->GetInferredDim(conflict_dim).DebugString());
        cur = cur->producer();
    }

    // 合并路径,返回完整冲突链
    std::vector<std::string> full_trace;
    full_trace.insert(full_trace.end(), trace_a.rbegin(), trace_a.rend());
    full_trace.insert(full_trace.end(), trace_b.begin(), trace_b.end());
    return full_trace;
}

打印出的冲突链示例如下,用户可以清晰地看到维度值在哪个节点开始产生分歧:

Shape conflict trace for dimension 2:
  [1] input_placeholder: dim[2] = N (unknown, symbolic)
  [2] embed_layer/Lookup: dim[2] = 768 (fixed from embedding table)
  [3] reshape_1: dim[2] = 384 (computed as 768/2)
  [4] conv_block_0: dim[2] = 192 (computed as 384/2)  ← value assigned here
  [5] conv_block_1: dim[2] = 96  (computed as 192/2)  ← value assigned here
  [6] conv_block_2: REQUIRES dim[2] = 128 (from attr 'out_channels')
                       ~~~~~~~~~~ CONFLICT: 96 != 128

六、两个关键陷阱与解决方案

陷阱一:推导规则循环依赖导致栈溢出

问题描述

在复杂图的推导过程中,如果两个或多个算子的 Shape 推导函数相互依赖对方的输出作为输入,且形成了一个闭环,就可能导致无限递归。典型场景出现在需要反向约束传播的情况下:

节点 A 的输出 Shape → 决定 → 节点 B 的输入约束
节点 B 的输出 Shape → 决定 → 节点 A 的输入约束(形成了闭环)

当 metadef 尝试推导节点 A 时,发现需要先知道 B 的某些信息;推导 B 时又需要 A 的信息,形成如下所示的死循环:

InferShape(A) → needs B's output → InferShape(B) →
needs A's output → InferShape(A) → needs B's output → ... (stack overflow!)

解决方案

metadef 框架实现了**三色标记法(Three-Color Marking)**来检测和打破循环依赖:

// 三色标记状态机
enum class NodeState {
    WHITE = 0,  // 未访问
    GRAY   = 1,  // 正在访问(推导中)
    BLACK  = 2,  // 推导完成
};

class InferShapeScheduler {
    std::unordered_map<std::string, NodeState> node_states_;
    std::unordered_map<std::string, int>        visit_count_;

    static constexpr int MAX_REVISIT_COUNT = 16;  // 兜底保护阈值

    Status InferNode(OpNode* node) {
        const auto& name = node->name();

        if (node_states_[name] == NodeState::BLACK) {
            return Status::OK();  // 已推导完成,直接返回
        }

        if (node_states_[name] == NodeState::GRAY) {
            // 检测到循环依赖!进入保守模式
            return HandleCycleDependency(node);
        }

        // 标记为 GRAY(正在推导中)
        node_states_[name] = NodeState::GRAY;

        // 尝试推导(可能因为等待依赖而返回 Retry 状态)
        auto status = node->RunInferShape();

        if (status.IsRetry()) {
            visit_count_[name]++;
            if (visit_count_[name] > MAX_REVISIT_COUNT) {
                return errors::FailedPrecondition(
                    "Shape inference stuck in potential cycle for node '", name,
                    "'. Visit count exceeded ", MAX_REVISIT_COUNT,
                    ". Check for circular dependencies in your model graph.");
            }
            // 保留 GRAY 标记(不降回 WHITE),允许下一轮重试
            return status;
        }

        // 推导成功,标记为 BLACK
        node_states_[name] = NodeState::BLACK;
        return Status::OK();
    }

    Status HandleCycleDependency(OpNode* node) {
        // 策略:对循环中的节点使用保守 Shape 估计
        // 即每个维度取最大可能值,确保内存分配不会不足
        std::vector<int64_t> conservative_dims;
        for (const auto& constraint : node->GetDimConstraints()) {
            conservative_dims.push_back(constraint.Max());
        }
        node->SetOutputShape(Shape(conservative_dims));
        return Status::OK();
    }
};

关键设计原则:永远不会将一个 GRAY 状态的节点降回 WHITE。如果一个节点被重新请求推导(因为它的某个依赖尚未就绪),visit_count 递增;当 visit_count 超过阈值时,判定为循环依赖,切换到保守推导模式,避免真正的栈溢出。

陷阱二:动态维度推导过早固化

问题描述

在图的拓扑排序推导过程中,如果某个动态维度在第一次遇到时就被赋予了一个具体值(即使后续有更充分的信息可以推导出更精确的值),就会导致**过早固化(Early Solidification)**问题。

举例说明:假设一个多分支融合的场景,一个中间层的输出被三个后续节点消费,其中两个节点的约束条件可以进一步缩小某个动态维度的取值范围,但由于第三个节点先完成了推导并"固化"了该维度值,另外两个节点的有效约束就无法再生效:

中间张量 T 的维度 D: 初始为 [N, H, W](N 未知)
  ├── 分支 A: N = 4(从自身约束推导得出)
  ├── 分支 B: N = 8(从自身约束推导得出)
  └── 分支 C: N = 16(从自身约束推导得出)
  
问题:如果分支 C 先完成推导并将 N 固化为 16,
     分支 A 和 B 的更精确约束就无法被接受,导致内存分配过大。

解决方案

metadef 引入了**区间追踪(Interval Tracking)**机制,不将每个维度固化为单一值,而是维护一个 [lower_bound, upper_bound] 区间。只有当区间收缩到只有一个可能值时,才真正固化:

// 动态维度的区间追踪器
class SymbolicDimTracker {
    // 符号维度区间结构
    struct DimInterval {
        int64_t lower;   // 下界
        int64_t upper;   // 上界
        bool is_fixed;   // 是否已固化(上下界相等)

        DimInterval(): lower(1), upper(INT64_MAX), is_fixed(false) {}

        void Narrow(int64_t new_lower, int64_t new_upper) {
            if (new_lower > lower) lower = new_lower;
            if (new_upper < upper) upper = new_upper;
            if (lower == upper) is_fixed = true;
        }
    };

    std::unordered_map<std::string, DimInterval> dim_intervals_;

public:
    void RegisterDim(const std::string& dim_id) {
        dim_intervals_[dim_id] = DimInterval();
    }

    // 每次收到新的约束时,尝试缩小区间
    Status ApplyConstraint(const std::string& dim_id,
                            const DimConstraint& constraint) {
        auto it = dim_intervals_.find(dim_id);
        if (it == dim_intervals_.end()) {
            return errors::NotFound("Symbolic dim '", dim_id, "' not found.");
        }

        DimInterval& interval = it->second;
        int64_t old_lower = interval.lower;
        int64_t old_upper = interval.upper;

        // 尝试应用约束来缩小区间
        constraint.ApplyTo(interval.lower, interval.upper);

        if (interval.lower > interval.upper) {
            return errors::InvalidArgument(
                "Conflicting constraints on dimension '", dim_id, "': "
                "constraint requires lower > upper. "
                "Previous interval: [", old_lower, ", ", old_upper, "]");
        }

        interval.is_fixed = (interval.lower == interval.upper);

        if (!interval.is_fixed) {
            // 未固化时,触发通知给所有等待该维度固化的节点
            NotifyConsumers(dim_id, interval.lower, interval.upper);
        }

        return Status::OK();
    }

    // 获取当前区间的字符串表示
    std::string DebugString(const std::string& dim_id) {
        const auto& interval = dim_intervals_[dim_id];
        if (interval.is_fixed) {
            return std::to_string(interval.lower);
        }
        return "[" + std::to_string(interval.lower) +
               ", " + std::to_string(interval.upper) + "]";
    }
};

使用这套机制后,即使某个分支先推导并写入了输出 Shape,metadef 也会区分"已固化的具体值"和"尚未固化的区间值"。后续的推导请求会继续接收新的约束信息,不断缩小区间,直到收敛到唯一解。如果直到最终编译时区间仍未收敛,metadef 会报告"动态维度无法完全确定,建议使用更保守的内存规划策略",而不是强制固化一个可能不准确的值。


七、实战代码

代码一:注册自定义算子的 Shape 推导函数

// my_custom_op_infer.cpp
// 为自定义算子注册 Shape 推导函数到 metadef 注册表

#include "graph/operator.h"
#include "graph/op_proto.h"
#include "graph/infer_shape_registry.h"

// 步骤 1: 定义推导函数
Status CustomLayerInferShape(InferShapeContext* ctx) {
    // 获取两个输入的 Shape
    auto x_shape = ctx->GetInputShape("x");
    auto weight_shape = ctx->GetInputShape("weight");

    // 读取自定义属性
    int64_t expand_ratio = ctx->GetAttr<int64_t>("expand_ratio");
    bool use_bias = ctx->GetAttr<bool>("use_bias");

    const auto& x_dims = x_shape.Dims();
    int64_t batch = x_dims[0];
    int64_t channels = x_dims[1];
    int64_t height = x_dims[2];
    int64_t width = x_dims[3];

    // 中间层通道数 = 输入通道 * expand_ratio
    int64_t mid_channels = channels * expand_ratio;

    std::vector<Shape> outputs;
    // 第一个输出:中间特征图
    outputs.push_back(Shape({batch, mid_channels, height, width}));
    // 第二个输出: gating mask(与输入同 Shape)
    outputs.push_back(Shape({batch, channels, height, width}));

    ctx->SetOutputs(outputs);
    return Status::OK();
}

// 步骤 2: 声明属性定义
static std::vector<AttrDef> BuildAttrDefs() {
    return {
        AttrDef("expand_ratio", Int64, 4),
        AttrDef("use_bias", Bool, false),
        AttrDef("activation", String, "gelu"),
    };
}

// 步骤 3: 将推导函数注册到全局注册表
static bool RegisterCustomLayerInferShape() {
    static InferShapeRegistry reg("CustomLayer");
    reg.SetInferShapeFn(CustomLayerInferShape);
    reg.SetAttrDefs(BuildAttrDefs());
    return reg.Register();
}

// 编译器会在全局对象构造阶段自动调用此注册函数
static bool registered = RegisterCustomLayerInferShape();

代码二:自定义算子的 Shape 推导单元测试

// my_custom_op_infer_test.cpp
// 为自定义算子的 Shape 推导函数编写单元测试

#include <gtest/gtest.h>
#include "graph/infer_shape_context.h"
#include "my_custom_op_infer.cpp"  // 引入推导函数

class CustomLayerInferShapeTest : public ::testing::Test {
protected:
    void SetUp() override {
        ctx_ = new InferShapeContext();
    }

    void TearDown() override {
        delete ctx_;
    }

    InferShapeContext* ctx_;
};

// 测试用例 1: 正常静态 Shape
TEST_F(CustomLayerInferShapeTest, StaticShape) {
    ctx_->SetInputShape("x", Shape({2, 64, 32, 32}));
    ctx_->SetInputShape("weight", Shape({256, 64, 1, 1}));
    ctx_->SetAttr("expand_ratio", static_cast<int64_t>(4));
    ctx_->SetAttr("use_bias", false);

    Status status = CustomLayerInferShape(ctx_);

    ASSERT_TRUE(status.IsOk()) << status.ErrorMsg();

    auto outputs = ctx_->GetOutputs();
    ASSERT_EQ(outputs.size(), 2);
    EXPECT_EQ(outputs[0].Dims(), std::vector<int64_t>({2, 256, 32, 32}));
    EXPECT_EQ(outputs[1].Dims(), std::vector<int64_t>({2, 64, 32, 32}));
}

// 测试用例 2: 动态批次维度
TEST_F(CustomLayerInferShapeTest, DynamicBatchDim) {
    ctx_->SetInputShape("x", Shape({SymbolicDim("N"), 64, 32, 32}));
    ctx_->SetInputShape("weight", Shape({256, 64, 1, 1}));
    ctx_->SetAttr("expand_ratio", static_cast<int64_t>(4));
    ctx_->SetAttr("use_bias", false);

    Status status = CustomLayerInferShape(ctx_);

    ASSERT_TRUE(status.IsOk());

    auto outputs = ctx_->GetOutputs();
    // N 应该被保留传播,不被修改
    EXPECT_TRUE(outputs[0].Dim(0).IsSymbolic());
    EXPECT_EQ(outputs[0].Dim(0).SymbolName(), "N");
}

// 测试用例 3: 属性验证
TEST_F(CustomLayerInferShapeTest, InvalidExpandRatio) {
    ctx_->SetInputShape("x", Shape({2, 64, 32, 32}));
    ctx_->SetInputShape("weight", Shape({0, 64, 1, 1}));  // 无效权重 Shape
    ctx_->SetAttr("expand_ratio", static_cast<int64_t>(0));

    Status status = CustomLayerInferShape(ctx_);

    EXPECT_FALSE(status.IsOk());
    EXPECT_TRUE(status.ErrorMsg().find("expand_ratio") != std::string::npos);
}

代码三:推导调试脚本

# infer_shape_debug.py
# 用于分析和调试 Shape 推导过程的辅助脚本

import json
import os
import subprocess
from typing import Dict, List, Any

class ShapeInferDebugger:
    def __init__(self, om_model_path: str):
        self.om_model_path = om_model_path
        self.infer_trace: List[Dict[str, Any]] = []

    def run_with_dump(self, output_dir: str) -> str:
        """运行模型并导出推导轨迹"""
        os.makedirs(output_dir, exist_ok=True)
        trace_path = os.path.join(output_dir, "shape_trace.json")

        env = os.environ.copy()
        env["GE_INFER_SHAPE_VERBOSE"] = "3"
        env["GE_SHAPE_DUMP_PATH"] = trace_path

        cmd = [
            "atc",
            "--model", self.om_model_path,
            "--output", os.path.join(output_dir, "model"),
            "--framework", "5",  # ONNX
            "--soc_version", "Ascend910",
        ]

        result = subprocess.run(cmd, env=env, capture_output=True, text=True)
        return result.stdout + "\n" + result.stderr

    def parse_trace(self, trace_path: str) -> None:
        """解析推导轨迹文件"""
        with open(trace_path, "r") as f:
            self.infer_trace = json.load(f)

    def find_bottleneck(self) -> List[str]:
        """找出推导耗时最长的算子"""
        if not self.infer_trace:
            return []

        timings = []
        for entry in self.infer_trace:
            node_name = entry.get("node", "unknown")
            time_ms = entry.get("infer_time_ms", 0)
            timings.append((node_name, time_ms))

        timings.sort(key=lambda x: x[1], reverse=True)
        return [f"  {name}: {time_ms:.2f}ms" for name, time_ms in timings[:10]]

    def check_dim_consistency(self, node_name: str) -> List[str]:
        """检查指定节点的维度一致性"""
        issues = []
        for entry in self.infer_trace:
            if entry.get("node") != node_name:
                continue
            inputs = entry.get("inputs", [])
            outputs = entry.get("outputs", [])
            for i, out_shape in enumerate(outputs):
                # 简单检查:元素数守恒
                in_elements = sum(self._prod(dims) for dims in inputs)
                out_elements = self._prod(out_shape)
                if in_elements != out_elements:
                    issues.append(
                        f"Node '{node_name}' output[{i}]: "
                        f"element count mismatch: in={in_elements}, out={out_elements}"
                    )
        return issues

    @staticmethod
    def _prod(dims: List[int]) -> int:
        result = 1
        for d in dims:
            if d <= 0:
                return 0  # 包含未知维度时返回 0
            result *= d
        return result

    def generate_report(self) -> str:
        """生成调试报告"""
        report = ["=" * 60]
        report.append("Shape Inference Debug Report")
        report.append("=" * 60)

        report.append("\n[TIMING] Top 10 slowest nodes:")
        for line in self.find_bottleneck():
            report.append(line)

        report.append("\n[CONSISTENCY] Dim consistency issues:")
        all_issues = []
        for entry in self.infer_trace:
            issues = self.check_dim_consistency(entry.get("node", ""))
            all_issues.extend(issues)
        if all_issues:
            for issue in all_issues:
                report.append(f"  ERROR: {issue}")
        else:
            report.append("  No consistency issues found.")

        return "\n".join(report)


# 使用示例
if __name__ == "__main__":
    debugger = ShapeInferDebugger("/path/to/your/model.onnx")
    output = debugger.run_with_dump("/tmp/shape_debug")
    print(output)

    trace_path = "/tmp/shape_debug/shape_trace.json"
    if os.path.exists(trace_path):
        debugger.parse_trace(trace_path)
        print(debugger.generate_report())

代码四:Shape 一致性自动化检查

# shape_consistency_checker.py
# 对整图进行 Shape 一致性自动化检查

from collections import defaultdict
from typing import Dict, List, Tuple, Optional

class ShapeConsistencyChecker:
    """
    检查模型图中所有相邻算子之间的 Shape 一致性。
    使用方法:实例化后调用 check_graph(ge_graph),返回所有不一致的位置。
    """

    def __init__(self):
        self.errors: List[Dict] = []

    def check_graph(self, graph) -> List[Dict]:
        """
        遍历图中的所有边(Edge),检查生产者的输出 Shape
        与消费者的输入 Shape 是否一致。
        """
        self.errors = []

        for edge in graph.edges():
            producer = edge.producer()
            consumer = edge.consumer()
            output_idx = edge.producer_output_index()
            input_idx = edge.consumer_input_index()

            producer_shape = producer.output_shape(output_idx)
            consumer_shape = consumer.input_shape(input_idx)

            mismatch = self._compare_shapes(
                producer_shape, consumer_shape,
                producer.name(), consumer.name(),
                output_idx, input_idx
            )

            if mismatch:
                self.errors.append(mismatch)

        return self.errors

    def _compare_shapes(
        self,
        shape_a: 'Shape',
        shape_b: 'Shape',
        node_a: str,
        node_b: str,
        idx_a: int,
        idx_b: int
    ) -> Optional[Dict]:

        dims_a = shape_a.dims()  # e.g., [N, 256, 7, 7]
        dims_b = shape_b.dims()  # e.g., [N, 256, 7, 7]

        if len(dims_a) != len(dims_b):
            return {
                "severity": "ERROR",
                "type": "RANK_MISMATCH",
                "producer": f"{node_a}:output[{idx_a}]",
                "consumer": f"{node_b}:input[{idx_b}]",
                "producer_shape": str(dims_a),
                "consumer_shape": str(dims_b),
                "message": f"Rank mismatch: producer has {len(dims_a)} dims, "
                           f"consumer expects {len(dims_b)} dims",
            }

        for axis, (d_a, d_b) in enumerate(zip(dims_a, dims_b)):
            if not self._dims_compatible(d_a, d_b):
                return {
                    "severity": "ERROR",
                    "type": "DIM_MISMATCH",
                    "producer": f"{node_a}:output[{idx_a}]",
                    "consumer": f"{node_b}:input[{idx_b}]",
                    "axis": axis,
                    "producer_dim": str(d_a),
                    "consumer_dim": str(d_b),
                    "message": f"Axis {axis} mismatch: producer has {d_a}, "
                               f"consumer expects {d_b}",
                }

        return None

    def _dims_compatible(self, dim_a, dim_b):
        """
        判断两个维度是否兼容。
        如果两者都是具体数值,必须相等。
        如果其中之一是符号维度(unknown),则兼容。
        """
        a_is_dynamic = isinstance(dim_a, str) or dim_a < 0
        b_is_dynamic = isinstance(dim_b, str) or dim_b < 0

        if a_is_dynamic or b_is_dynamic:
            return True  # 至少一个动态,保守认为兼容
        return dim_a == dim_b

    def print_report(self) -> None:
        """打印格式化的检查报告"""
        print(f"\n{'=' * 60}")
        print(f"Shape Consistency Check Report")
        print(f"Total errors found: {len(self.errors)}")
        print(f"{'=' * 60}\n")

        if not self.errors:
            print("All shapes are consistent. No errors detected.")
            return

        for i, err in enumerate(self.errors, 1):
            print(f"[{i}] {err['severity']}{err['type']}")
            print(f"    Producer: {err['producer']}  Shape: {err.get('producer_shape', err.get('producer_dim'))}")
            print(f"    Consumer: {err['consumer']}  Shape: {err.get('consumer_shape', err.get('consumer_dim'))}")
            print(f"    {err['message']}")
            print()


# 使用示例
def validate_model_shape(model_path: str):
    import acl
    graph = acl.load_model(model_path).get_graph()
    checker = ShapeConsistencyChecker()
    errors = checker.check_graph(graph)
    checker.print_report()
    return len(errors) == 0

代码五:Ascend C 中的 Shape 感知内存分配

// ascend_c_shape_alloc.cpp
// 在 Ascend C 算子实现中使用 Shape 推导结果进行内存分配

#include "graph/tensor.h"
#include "graph/graph.h"
#include "graph/operator.h"

// Ascend C 风格:使用编译期推导出的 Shape 分配本地张量
class GemmOp {
public:
    GemmOp(OpDesc* op_desc) : op_desc_(op_desc) {}

    // 在初始化阶段利用 Shape 推导结果分配工作内存
    bool Init() {
        auto input_x_shape = op_desc_->GetInputDesc(0).GetShape();   // [M, K]
        auto input_w_shape = op_desc_->GetInputDesc(1).GetShape();   // [K, N]
        auto output_y_shape = op_desc_->GetOutputDesc(0).GetShape(); // [M, N]

        // 通过 metadef 的 Shape 推导,上下文已经提供了精确的维度值
        // Ascend C 编译器将这些值内联到生成的代码中
        M_ = input_x_shape.GetDim(0);  // 来自推导,M > 0
        K_ = input_x_shape.GetDim(1);
        N_ = input_w_shape.GetDim(1);

        // 计算所需的内存大小(字节数)
        size_t bytes_per_element = GetDataTypeSize(DT_FLOAT);
        size_t output_size_bytes = M_ * N_ * bytes_per_element;
        size_t workspace_size_bytes = K_ * N_ * bytes_per_element;  // 中间结果

        // 分配本地张量(UB 内存池)
        output_tensor_ = ctx_->AllocTensor<TPosition::VECIN>(output_size_bytes);
        workspace_tensor_ = ctx_->AllocTensor<TPosition::VECIN>(workspace_size_bytes);

        if (!output_tensor_.is_valid() || !workspace_tensor_.is_valid()) {
            GELOGE(ACL_ERROR_GE, "Failed to allocate workspace for Gemm. "
                "Required: %zu bytes (M=%ld, N=%ld, K=%ld)",
                workspace_size_bytes, M_, N_, K_);
            return false;
        }

        GELOGI("GemmOp Init: M=%ld, K=%ld, N=%ld, workspace=%zu bytes",
               M_, K_, N_, workspace_size_bytes);
        return true;
    }

private:
    OpDesc* op_desc_;
    int64_t M_, K_, N_;
    LocalTensor<float> output_tensor_;
    LocalTensor<float> workspace_tensor_;
};

代码六:约束传播的手动干预接口

// manual_constraint_api.cpp
// 提供给高级用户手动干预约束传播的 API

#include "graph/infer_shape_context.h"
#include "graph/shape_constraint.h"

// 场景:用户知道某个维度必须满足特定约束,但自动推导无法发现
// 例如:某量化算子要求通道数必须是 32 的倍数

class ConstraintOverride {
public:
    // 添加用户自定义的维度约束
    static Status AddDimConstraint(
        InferShapeContext* ctx,
        const std::string& node_name,
        int output_idx,
        int axis,
        const std::function<bool(int64_t)>& validator,
        const std::string& error_hint
    ) {
        auto shape = ctx->GetOutputShape(output_idx);
        int64_t dim_val = shape.Dim(axis);

        if (dim_val > 0 && !validator(dim_val)) {
            return errors::InvalidArgument(
                "User-defined constraint violation on node '", node_name,
                "', output ", output_idx, ", axis ", axis,
                ": value ", dim_val, " failed validation. ", error_hint,
                " Hint: consider using a Pad op to align the dimension to "
                "the required value.");
        }
        return Status::OK();
    }

    // 强制指定某个维度的范围(用于调试或特殊硬件要求)
    static void ForceDimRange(
        InferShapeContext* ctx,
        int output_idx,
        int axis,
        int64_t min_val,
        int64_t max_val
    ) {
        ShapeConstraint constraint;
        constraint.type = ShapeConstraintType::RANGE;
        constraint.axis = axis;
        constraint.range_min = min_val;
        constraint.range_max = max_val;
        ctx->AddConstraint(output_idx, constraint);
    }
};

// 使用示例
Status QuantizedConvInferShape(InferShapeContext* ctx) {
    // ... 执行标准的 Conv Shape 推导 ...
    auto out_shape = ctx->GetOutputShape(0);

    // 量化约束:输出通道必须是 32 的倍数
    int64_t out_channels = out_shape.Dim(1);
    int64_t aligned_channels = ((out_channels + 31) / 32) * 32;

    if (out_channels % 32 != 0) {
        GELOGW("Conv output channels %ld not aligned to 32. "
               "Auto-aligning to %ld for quantized inference.",
               out_channels, aligned_channels);

        // 约束传播:通知后续节点使用对齐后的通道数
        ConstraintOverride::ForceDimRange(ctx, 0, 1,
            aligned_channels, aligned_channels);
    }

    return Status::OK();
}

代码七:整图 Shape 推导批处理调用

// batch_infer_shape.cpp
// 批量执行图中多个节点的 Shape 推导(利用并行性加速)

#include "graph/infer_shape_scheduler.h"
#include "graph/op_node.h"
#include <tbb/concurrent_unordered_map.h>
#include <tbb/parallel_for.h>

class BatchInferShapeExecutor {
public:
    explicit BatchInferShapeExecutor(Graph* graph)
        : graph_(graph), scheduler_(graph) {}

    // 批量推导入口:接受节点列表,尽可能并行执行
    Status BatchInfer(const std::vector<std::string>& node_names) {
        std::vector<OpNode*> nodes;
        for (const auto& name : node_names) {
            OpNode* node = graph_->FindNode(name);
            if (!node) {
                return errors::NotFound("Node '", name, "' not found in graph.");
            }
            nodes.push_back(node);
        }

        // Step 1: 对节点进行拓扑排序
        std::vector<OpNode*> sorted = TopologicalSort(nodes);

        // Step 2: 按拓扑序分批执行
        // 同一批内的节点没有相互依赖,可以并行推导
        std::vector<std::vector<OpNode*>> wavefronts;
        BuildWavefronts(sorted, &wavefronts);

        for (const auto& wave : wavefronts) {
            // TBB 并行执行本 wave 内所有节点的推导
            tbb::parallel_for(
                tbb::blocked_range<size_t>(0, wave.size()),
                [&](const tbb::blocked_range<size_t>& range) {
                    for (size_t i = range.begin(); i < range.end(); ++i) {
                        Status s = scheduler_.InferNode(wave[i]);
                        if (!s.IsOk()) {
                            failed_nodes_.insert({wave[i]->name(), s});
                        }
                    }
                }
            );

            // 如果本 wave 有失败节点,不继续下一 wave(确保依赖关系)
            if (!failed_nodes_.empty()) {
                break;
            }
        }

        return failed_nodes_.empty() ? Status::OK()
                                     : Status::PartialFailure(failed_nodes_);
    }

    const std::unordered_map<std::string, Status>& GetFailures() const {
        return failed_nodes_;
    }

private:
    std::vector<OpNode*> TopologicalSort(const std::vector<OpNode*>& nodes) {
        // Kahn's algorithm 拓扑排序
        std::vector<OpNode*> result;
        std::unordered_map<OpNode*, int> in_degree;
        std::queue<OpNode*> zero_degree;

        for (OpNode* n : nodes) in_degree[n] = n->input_nodes().size();

        for (auto& [n, deg] : in_degree) {
            if (deg == 0) zero_degree.push(n);
        }

        while (!zero_degree.empty()) {
            OpNode* cur = zero_degree.front(); zero_degree.pop();
            result.push_back(cur);
            for (OpNode* succ : cur->output_nodes()) {
                if (--in_degree[succ] == 0) {
                    zero_degree.push(succ);
                }
            }
        }

        return result;
    }

    void BuildWavefronts(const std::vector<OpNode*>& sorted,
                          std::vector<std::vector<OpNode*>>* wavefronts) {
        wavefronts->clear();
        std::unordered_set<OpNode*> completed;

        for (OpNode* node : sorted) {
            // 检查该节点的所有输入是否都已在之前的 wavefront 中完成
            bool all_inputs_done = true;
            for (OpNode* input : node->input_nodes()) {
                if (completed.find(input) == completed.end()) {
                    all_inputs_done = false;
                    break;
                }
            }

            if (all_inputs_done) {
                if (wavefronts->empty() ||
                    !wavefronts->back().empty()) {
                    wavefronts->push_back({});
                }
                wavefronts->back().push_back(node);
            } else {
                completed.insert(node);
            }
        }
    }

    Graph* graph_;
    InferShapeScheduler scheduler_;
    std::unordered_map<std::string, Status> failed_nodes_;
};

八、实战进阶:完整的自定义 Shape 推导插件开发流程

以下整合以上所有组件,演示从注册到调试的完整开发流程:

// plugin_development_workflow.cpp
// 完整的自定义算子 Shape 推导插件开发流程

// ====== 第一步:定义算子的元数据 ======
static const OpMetadata kMyCustomOp = {
    .name = "MyCustomOp",
    .input_names = {"data", "mask", "scale"},
    .output_names = {"output", "confidence"},
    .attr_defs = {
        AttrDef("eps", Float, 1e-5f),
        AttrDef("norm_type", String, "layer_norm"),
    },
};

// ====== 第二步:实现 Shape 推导函数 ======
Status MyCustomOpInferShape(InferShapeContext* ctx) {
    // 2.1 获取输入 Shape
    auto data_shape = ctx->GetInputShape("data");   // [B, S, H]
    auto mask_shape = ctx->GetInputShape("mask");  // [B, S]
    auto scale_shape = ctx->GetInputShape("scale"); // [H]

    const auto& data_dims = data_shape.Dims();

    // 2.2 维度一致性检查(mask 的 S 应与 data 的 S 匹配)
    if (data_dims.size() < 2 || mask_shape.Dims().size() < 2) {
        return errors::InvalidArgument(
            "MyCustomOp: data must be 3D [B,S,H], mask must be 2D [B,S]");
    }

    int64_t B = data_dims[0];
    int64_t S = data_dims[1];
    int64_t H = data_dims[2];

    // 2.3 应用用户定义的约束
    float eps = ctx->GetAttr<float>("eps");
    if (eps <= 0) {
        return errors::InvalidArgument("eps must be positive, got ", eps);
    }

    // 2.4 计算输出 Shape
    ctx->SetOutputShape("output", Shape({B, S, H}));
    ctx->SetOutputShape("confidence", Shape({B, S}));

    // 2.5 设置数据类型
    ctx->SetOutputType("output", DT_FLOAT);
    ctx->SetOutputType("confidence", DT_FLOAT);

    return Status::OK();
}

// ====== 第三步:注册到 metadef ======
METADEF_REGISTER_OP_INFER_SHAPE(kMyCustomOp, MyCustomOpInferShape);

// ====== 第四步:在模型图中使用该算子 ======
// 这一步通常由模型转换工具(如 ATC)自动处理
// 用户只需要在模型定义阶段使用该算子,Shape 推导自动触发

九、结尾

metadef 的 Shape 推导引擎是昇腾 CANN 编译器栈中一个看似基础、实则极其精妙的模块。它在编译期的精确推导,直接决定了运行时的内存使用效率与执行正确性。从约束传播到符号化维度追踪,从三色标记防循环到区间追踪防固化,每一个设计决策背后都有深刻的工程教训。

cann

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐