请添加图片描述

一、前言

昇腾CANN(Compute Architecture for Neural Networks)是围绕昇腾NPU构建的异构计算平台,覆盖从AI框架适配到算子开发的完整编译执行链路。metadef(Meta Definition)是CANN架构中的基础组件仓,为Graph Engine(GE)和算子仓库提供共享的基础数据结构和接口抽象。它定义了算子原语的统一规范和图IR(Intermediate Representation)的表示标准,是连接前端框架、图编译器与后端硬件的关键桥梁。本文将从架构视角深度解析metadef的核心设计理念。

二、为什么需要算子原语定义层

2.1 前端框架的算子语义差异

AI前端框架(MindSpore、PyTorch、TensorFlow)各自定义了不同的算子集合和语义规范。以卷积操作为例:

框架 API名称 数据排布默认值 Padding语义
PyTorch torch.nn.Conv2d NCHW same/valid/显式
TensorFlow tf.nn.conv2d NHWC SAME/VALID/显式
MindSpore mindspore.nn.Conv2d NCHW same/valid/pad

同一个数学语义的算子,在不同框架中可能有不同的名称、参数顺序、数据排布偏好和默认行为。这种差异导致:

  1. 语义鸿沟:框架A的Conv2d(3,64,3,padding=1)与框架B的Conv2d(3,64,3,padding=“SAME”)即使数学含义相同,其IR表示也完全不同。
  2. 属性不兼容:有些框架的算子属性是命名参数,有些是位置参数,属性的默认值和约束条件各不相同。
  3. Shape推导规则不一致:相同的输入Shape,不同框架的推导结果可能因Padding规则不同而有差异。

2.2 跨框架对齐的根本问题

要解决上述差异,传统做法是为每个框架编写独立的图编译器后端,但这会导致N个框架×M个后端的复杂度爆炸。metadef通过引入统一的算子原语定义层,将问题转化为:

N个前端框架 → metadef统一原语定义 → M个后端设备

把 N×M 的复杂度降为 N+M。这一层的核心使命是:

  • 定义标准算子原语(Operator Primitives):一套与框架无关、与硬件无关的算子规范。
  • 建立属性规范:统一属性命名、类型、默认值和约束条件。
  • 规定数据类型系统:包括DataType枚举、Format枚举、Shape表达方式。
  • 提供Shape推导契约:每个算子注册时声明其输入输出Shape的推导函数。

这正是metadef存在的根本原因——它不是某个框架的附属品,而是CANN全栈的"通用语言层"。

三、metadef 的核心设计

3.1 算子原语集合

metadef通过OpRegistrationData类提供算子注册机制,开发者可以声明算子的类型、输入输出、属性和推导函数。这套机制构成了CANN生态中所有算子管理的基础。

代码块 1:算子原语注册 - 通过 OpRegistrationData 声明一个自定义算子

#include "register/register.h"

using namespace ge;

// 注册名为 "CustomConv2d" 的算子原语
OpRegistrationData customConvReg("CustomConv2d");

// 定义输入:input, weight
customConvReg.Input("input")
             .ParamType(REQUIRED)
             .DataType({DT_FLOAT16, DT_FLOAT});
customConvReg.Input("weight")
             .ParamType(REQUIRED)
             .DataType({DT_FLOAT16, DT_FLOAT});

// 定义输出:output
customConvReg.Output("output")
             .DataType({DT_FLOAT16, DT_FLOAT});

// 定义属性:stride, pad, dilation
customConvReg.Attr("stride", ListInt::CreateList({1, 1}));
customConvReg.Attr("pad", ListInt::CreateList({0, 0, 0, 0}));
customConvReg.Attr("dilation", ListInt::CreateList({1, 1}));
customConvReg.Attr("group", AttrValue::INTEGER(1));

代码块 2:查询已注册的算子原语信息

#include "register/register.h"
#include <iostream>

void QueryOpInfo(const std::string& opType) {
    auto* regData = OpRegistrationData::GetInstance().GetOpRegistrationData(opType);
    if (regData == nullptr) {
        std::cerr << "算子 " << opType << " 未注册" << std::endl;
        return;
    }

    std::cout << "算子类型: " << opType << std::endl;

    // 遍历输入
    for (const auto& input : regData->GetInputs()) {
        std::cout << "  输入: " << input.GetName()
                  << ", 类型数: " << input.GetDataType().size() << std::endl;
    }

    // 遍历输出
    for (const auto& output : regData->GetOutputs()) {
        std::cout << "  输出: " << output.GetName()
                  << ", 类型数: " << output.GetDataType().size() << std::endl;
    }

    // 遍历属性
    for (const auto& attr : regData->GetAttrs()) {
        std::cout << "  属性: " << attr.GetName()
                  << ", 默认值类型: " << attr.GetDefaultValue().GetValueType() << std::endl;
    }
}

3.2 属性系统

metadef的属性系统支持多种基础数据类型,并为每个属性提供默认值和校验机制。属性定义是算子原语规范的重要组成。

代码块 3:属性类型定义与校验

#include "graph/attr_value.h"
#include "graph/utils/attr_utils.h"
#include <iostream>

void AttrValidationDemo() {
    // 创建各种属性值
    auto intAttr = AttrValue::Create(42);
    auto floatAttr = AttrValue::Create(3.14f);
    auto strAttr = AttrValue::Create("RELU");
    auto listIntAttr = AttrValue::Create(std::vector<int64_t>({1, 2, 3}));
    auto boolAttr = AttrValue::Create(true);

    std::cout << "intAttr value: " << intAttr.GetInt() << std::endl;
    std::cout << "floatAttr value: " << floatAttr.GetFloat() << std::endl;
    std::cout << "strAttr value: " << strAttr.GetString() << std::endl;

    // 类型校验:尝试从 int 读取 float 会失败
    if (intAttr.GetValueType() == AttrValue::VT_FLOAT) {
        std::cout << "int is float: " << intAttr.GetFloat() << std::endl;
    } else {
        std::cout << "intAttr is not float type (VT: "
                  << static_cast<int>(intAttr.GetValueType()) << ")" << std::endl;
    }
}

3.3 Shape 函数

每个算子原语注册时可以绑定InferShape函数,用于在编译期根据输入Shape推导输出Shape。这是图编译中必不可少的一环。

代码块 4:InferShape 函数注册

#include "graph/operator_reg.h"
#include "graph/infer_shape_context.h"

IMPLEMT_COMMON_INFERFUNC(CustomConv2dInferShape) {
    // 获取输入 shape
    auto inputShape = op.GetInputDescByName("input").GetShape();
    auto weightShape = op.GetInputDescByName("weight").GetShape();

    // 解析属性
    auto stride = op.GetAttr("stride")->GetListInt();
    auto pad = op.GetAttr("pad")->GetListInt();

    // NCHW 格式:input = [N, C, H, W], weight = [K, C, R, S]
    auto N = inputShape.GetDim(0);
    auto H = inputShape.GetDim(2);
    auto W = inputShape.GetDim(3);
    auto K = weightShape.GetDim(0);
    auto R = weightShape.GetDim(2);
    auto S = weightShape.GetDim(3);

    // 计算输出 shape
    auto outH = (H + 2 * pad[0] - R) / stride[0] + 1;
    auto outW = (W + 2 * pad[1] - S) / stride[1] + 1;

    // 设置输出 shape
    ge::Shape outShape({N, K, outH, outW});
    op.GetOutputDescByName("output").SetShape(outShape);

    return GRAPH_SUCCESS;
}

// 将 InferShape 函数注册到算子原语
MAKE_CUSTOM_OP(CustomConv2d, CustomConv2dInferShape)

代码块 5:Shape 查询与操作

#include "graph/tensor.h"
#include "graph/shape.h"
#include <iostream>

void ShapeManipulation() {
    // 构造一个 4D Shape [N, C, H, W]
    ge::Shape shape({1, 3, 224, 224});
    std::cout << "Shape dims: " << shape.GetDimNum() << std::endl;
    std::cout << "Dim 0 (N): " << shape.GetDim(0) << std::endl;
    std::cout << "Dim 1 (C): " << shape.GetDim(1) << std::endl;
    std::cout << "Dim 2 (H): " << shape.GetDim(2) << std::endl;
    std::cout << "Dim 3 (W): " << shape.GetDim(3) << std::endl;

    // 获取 shape 的连续表示
    auto dims = shape.GetDims();
    for (size_t i = 0; i < dims.size(); ++i) {
        std::cout << "  dim[" << i << "] = " << dims[i] << std::endl;
    }

    // 获取总元素数
    int64_t total = 1;
    for (auto d : dims) {
        if (d > 0) total *= d;
    }
    std::cout << "Total elements: " << total << std::endl;
}

3.4 数据类型规范

metadef统一管理数据类型(DataType)和数据排布格式(Format),这是IR表示的核心基础。

代码块 6:数据类型枚举与格式查询

#include "graph/types.h"
#include "graph/utils/type_utils.h"
#include <iostream>

void DataTypeDemo() {
    // 查询数据类型的字符串表示
    std::cout << "DT_FLOAT  -> " << TypeUtils::DataTypeToSerialString(DT_FLOAT) << std::endl;
    std::cout << "DT_FLOAT16 -> " << TypeUtils::DataTypeToSerialString(DT_FLOAT16) << std::endl;
    std::cout << "DT_INT8   -> " << TypeUtils::DataTypeToSerialString(DT_INT8) << std::endl;
    std::cout << "DT_BF16   -> " << TypeUtils::DataTypeToSerialString(DT_BF16) << std::endl;

    // 查询数据格式的字符串表示
    std::cout << "FORMAT_NCHW -> " << TypeUtils::FormatToSerialString(FORMAT_NCHW) << std::endl;
    std::cout << "FORMAT_NHWC -> " << TypeUtils::FormatToSerialString(FORMAT_NHWC) << std::endl;
    std::cout << "FORMAT_ND   -> " << TypeUtils::FormatToSerialString(FORMAT_ND) << std::endl;
    std::cout << "FORMAT_NC1HWC0 -> " << TypeUtils::FormatToSerialString(FORMAT_NC1HWC0) << std::endl;

    // 根据字符串反查枚举
    ge::DataType dt;
    TypeUtils::SerialStringToDataType("DT_FLOAT", dt);
    std::cout << "Parsed DT_FLOAT -> enum value: " << static_cast<int>(dt) << std::endl;
}

void SizeInBytesDemo() {
    // 计算指定数据类型和元素数占用的内存大小
    int64_t totalBytes = GetSizeInBytes(1024, DT_FLOAT);
    std::cout << "1024 floats occupy " << totalBytes << " bytes" << std::endl;

    auto size = GetSizeByDataType(DT_FLOAT16);
    std::cout << "float16 occupies " << size << " bytes per element" << std::endl;
}

四、图 IR 统一表示的实现

4.1 三层转换架构

CANN的图编译流程可以抽象为三个层级:

计算图 (前端框架感知)
    │
    ▼
Parser 层: 将框架计算的图解析为 metadef IR
    │
    ▼
metadef IR (统一中间表示,算子以原语形式表达)
    │
    ▼
Graph Compiler: 消费 metadef IR,进行图优化、算子选择、内存安排
    │
    ▼
多后端 Lowering: 设备无关优化 → 设备相关优化 → 可执行指令

第一层:Parser(框架解析)。CANN通过AscendCL等接口接收来自前端框架的模型,Parser层负责将框架特有的图表示(如MindSpore的ANF图、PyTorch的TorchScript、TensorFlow的GraphDef)解析为metadef IR。这一层完成了框架算子到metadef原语的映射。

第二层:metadef IR(统一中间表示)。这是核心抽象层,所有框架的算子统一用metadef定义的算子原语、TensorDesc、Shape、AttrValue等数据结构表达。IR不再包含任何框架特定的语义,只保留与计算相关的本质信息。

第三层:Lowering(多后端分发)。Graph Compiler消费统一的metadef IR,进行算子选择、内存优化、图切分等操作,最终lower到具体的硬件后端执行。

代码块 7:TensorDesc 构建 - 构造 metadef IR 的核心数据结构

#include "graph/tensor.h"
#include "graph/types.h"
#include <iostream>

void BuildTensorDesc() {
    // 创建一个 TensorDesc
    ge::TensorDesc desc;

    // 设置 shape
    ge::Shape shape({1, 3, 224, 224});
    desc.SetShape(shape);

    // 设置数据类型
    desc.SetDataType(ge::DT_FLOAT);

    // 设置数据排布格式
    desc.SetFormat(ge::FORMAT_NCHW);

    // 设置名称
    desc.SetName("input_tensor");

    // 查询
    std::cout << "Tensor name: " << desc.GetName() << std::endl;
    std::cout << "DataType: " << static_cast<int>(desc.GetDataType()) << std::endl;
    std::cout << "Format: " << static_cast<int>(desc.GetFormat()) << std::endl;
    std::cout << "Shape dims: " << desc.GetShape().GetDimNum() << std::endl;

    // 或者构建时直接初始化
    ge::TensorDesc initDesc(
        ge::Shape({1024, 512}),
        ge::FORMAT_ND,
        ge::DT_FLOAT16
    );
    std::cout << "InitDesc format: "
              << static_cast<int>(initDesc.GetFormat()) << std::endl;
}

代码块 8:构建 metadef 计算图(简化示意)

#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include <iostream>

void BuildComputeGraph() {
    // 创建计算图
    ge::Graph graph("sample_graph");

    // 创建算子描述
    auto addOpDesc = std::make_shared<ge::OpDesc>("add_1", "Add");
    auto mulOpDesc = std::make_shared<ge::OpDesc>("mul_1", "Mul");

    // 为 Add 算子设置输入输出 TensorDesc
    addOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
    addOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
    addOpDesc->AddOutputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));

    // 为 Mul 算子设置输入输出 TensorDesc
    mulOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
    mulOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
    mulOpDesc->AddOutputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));

    // 创建算子节点
    auto addNode = graph.AddNode(addOpDesc);
    auto mulNode = graph.AddNode(mulOpDesc);

    // 添加边:add 的输出连接到 mul 的第一个输入
    graph.AddEdge(addNode, 0, mulNode, 0);

    std::cout << "计算图创建完成,包含 "
              << graph.GetAllNodes().size() << " 个节点" << std::endl;
}

4.2 IR 中的数据格式转换

在metadef IR层面,数据排布格式的转换是一个关键能力。昇腾NPU内部使用5D格式(NC1HWC0),而前端框架通常使用4D格式(NCHW/NHWC)。metadef通过TypeUtils系列函数管理这种格式转换。

代码块 9:Format 转换查询

#include "graph/types.h"
#include <iostream>
#include <cstring>

void FormatConversionDemo() {
    // CANN 支持多种数据格式,通过 C0/SubFormat 机制管理
    ge::Format fmt = FORMAT_NC1HWC0;

    // 5D 格式 NC1HWC0 的物理含义:
    // N: batch, C1: C // 16, H: height, W: width, C0: 16 (对齐单位)
    int64_t N = 1, C = 32, H = 56, W = 56, C0 = 16;
    int64_t C1 = (C + C0 - 1) / C0;

    int64_t nchwSize = N * C * H * W;
    int64_t nc1hwc0Size = N * C1 * H * W * C0;

    std::cout << "NCHW storage: " << nchwSize << " elements" << std::endl;
    std::cout << "NC1HWC0 storage: " << nc1hwc0Size << " elements (aligned)" << std::endl;
    std::cout << "Format overhead: " << (nc1hwc0Size - nchwSize) << " elements" << std::endl;
}

五、与 ONNX / MLIR 的对比

5.1 metadef vs ONNX

对比维度 ONNX metadef
定位 跨框架模型交换格式 CANN 架构内部中间表示
算子定义 固定算子集(opsets版本化) 可扩展算子原语注册机制
与硬件关系 框架中立、硬件中立 专为昇腾NPU优化
Shape推导 可选、依赖框架 强制注册InferShape函数
属性系统 协议缓冲区(protobuf)序列化 C++原生类型+AttrValue
数据格式 NCHW/NHWC 等标准格式 含NC1HWC0等昇腾特有格式

metadef比ONNX更贴近硬件的原因是:ONNX的设计目标是跨平台交换,因此其算子定义是最小公共子集,不包含特定硬件的信息。而metadef的设计目标是CANN内部的编译优化,因此:

  • 支持昇腾特有格式:NC1HWC0、FRACTAL_NZ等5D格式直接内建在IR中。
  • 支持硬件特性表达:算子原语定义中可以直接声明对Ascend C Kernel的需求。
  • 紧耦合的推导链:InferShape、InferDataType、InferFormat形成一个完整的编译期推导链。

5.2 metadef vs MLIR

对比维度 MLIR metadef
定位 可扩展多级IR编译框架 面向图编译的原语定义层
框架体重 重型,支持Dialect自定义 轻量,聚焦算子定义和Graph
学习曲线 陡峭(Dialect、Pass、Pattern) 平缓(C++类+注册机制)
适用场景 复杂编译场景、多硬件 CANN图编译和闭管场景
工程复杂度 高,需Dialect设计 低,直接扩展OpRegistrationData

metadef比MLIR更轻量的原因在于:MLIR推崇通过Dialect机制表达所有计算语义,但其代价是框架本身的复杂度(每个Dialect需要定义Type、Attribute、Operation、Pattern Rewrite等)。metadef选择了一个更务实的方案:将算子原语定义集中在OpRegistrationData一套API中,Shape推导通过注册InferShape函数完成,不需要额外的模式匹配框架。

5.3 综合定位

metadef在技术路线上与ONNX和MLIR既有交集又有差异:

  • 与ONNX的交集:都定义了算子规格、数据类型和属性系统。差异在于ONNX面向交换,metadef面向编译。
  • 与MLIR的交集:都提供中间表示。差异在于MLIR的Dialect框架是一种可复用的IR设计范式,metadef是一种面向CANN场景的专用IR。

一句话总结:metadef比ONNX更贴近硬件、比MLIR更轻量,是CANN场景下的最优选择。

六、metadef 在 CANN 编译流水线中的位置

6.1 编译流水线全景

前端框架模型
    │
    ▼
┌─────────────────────────────────────┐
│  AscendCL / Framework Parser        │
│  (将框架算子映射为metadef算子原语)   │
└──────────┬──────────────────────────┘
           │ metadef IR (由 TensorDesc, OpDesc, Graph 构成)
           ▼
┌─────────────────────────────────────┐
│  Graph Compiler (GE)                │
│  - 图优化:常量折叠、死代码消除     │
│  - 算子选择:选择最优Ascend C Kernel │
│  - 内存安排:地址分配、复用         │
│  - 图切分:NPU/CPU协同              │
└──────────┬──────────────────────────┘
           │ 优化后的 metadef IR
           ▼
┌─────────────────────────────────────┐
│  Lowering: 多后端分发               │
│  → 设备相关优化                     │
│  → 生成可执行指令序列               │
└─────────────────────────────────────┘

Parser消费前端框架的图,输出metadef IR。Graph Compiler(GE)消费metadef IR,进行图优化和算子选择。metadef在其中发挥了契约层的作用——Parser按照metadef的算子原语规范产出IR,GE按照同样的规范消费IR。

6.2 从 IR 查询到执行上下文

在编译流水线中,metadef的InferShape上下文(InferShapeContext)和Tiling上下文(TilingContext)分别在编译期和算机选择阶段被使用。

代码块 10:从 IR 算子中查询属性并进行 Shape 推导

#include "graph/infer_shape_context.h"
#include <iostream>

ge::graphStatus InferShapeForPooling(InferShapeContext& ctx) {
    // 获取输入 tensor 描述
    auto inputDesc = ctx.GetInputDesc(0);
    auto inputShape = inputDesc.GetShape();

    if (inputShape.GetDimNum() != 4) {
        std::cerr << "Expected 4D input, got " << inputShape.GetDimNum() << "D" << std::endl;
        return GRAPH_FAILED;
    }

    // 解析 pooling 参数
    auto kernelSize = ctx.GetAttr("kernel_size")->GetListInt();
    auto stride = ctx.GetAttr("strides")->GetListInt();
    auto padMode = ctx.GetAttr("padding")->GetString();

    auto N = inputShape.GetDim(0);
    auto C = inputShape.GetDim(1);
    auto H = inputShape.GetDim(2);
    auto W = inputShape.GetDim(3);

    int64_t outH, outW;
    if (padMode == "SAME") {
        outH = (H + stride[0] - 1) / stride[0];
        outW = (W + stride[1] - 1) / stride[1];
    } else {
        outH = (H - kernelSize[0]) / stride[0] + 1;
        outW = (W - kernelSize[1]) / stride[1] + 1;
    }

    ge::Shape outputShape({N, C, outH, outW});
    ctx.SetOutputShape(0, outputShape);

    return GRAPH_SUCCESS;
}

代码块 11:利用 Range 和 StorageShape 进行动态 Shape 推导

#include "exe_graph/runtime/range.h"
#include "exe_graph/runtime/storage_shape.h"
#include <iostream>

void DynamicShapeDemo() {
    // Range 描述一个动态范围 [min, max]
    gert::Range batchRange(1, 64);

    // StorageShape 同时存储 origin shape 和 storage shape
    gert::StorageShape storageShape(
        ge::Shape({-1, 3, 224, 224}),   // origin shape(含动态维度 -1)
        ge::Shape({1, 3, 224, 224})     // storage shape(运行时实际值)
    );

    // 设置 CompileTimeTensorDesc
    gert::CompileTimeTensorDesc compileDesc;
    compileDesc.dataType = ge::DT_FLOAT;
    compileDesc.format = ge::FORMAT_NCHW;

    std::cout << "Compile-time tensor: dtype="
              << static_cast<int>(compileDesc.dataType)
              << ", format=" << static_cast<int>(compileDesc.format) << std::endl;
}

代码块 12:属性校验 - 检查属性的合法值

#include "graph/attr_value.h"
#include <iostream>
#include <set>

bool ValidateConvAttrs(const std::map<std::string, ge::AttrValue>& attrs) {
    bool valid = true;

    // 校验 group 参数:必须 >= 1
    auto groupIt = attrs.find("group");
    if (groupIt != attrs.end()) {
        int64_t group = groupIt->second.GetInt();
        if (group < 1) {
            std::cerr << "group must >= 1, got " << group << std::endl;
            valid = false;
        }
    }

    // 校验 padding 参数:必须是 4 个非负整数 [top, bottom, left, right]
    auto padIt = attrs.find("pad");
    if (padIt != attrs.end()) {
        auto pad = padIt->second.GetListInt();
        if (pad.size() != 4) {
            std::cerr << "pad must have 4 values, got " << pad.size() << std::endl;
            valid = false;
        }
        for (auto v : pad) {
            if (v < 0) {
                std::cerr << "pad values must be non-negative, got " << v << std::endl;
                valid = false;
            }
        }
    }

    // 校验 activation 参数:必须是有效激活函数名
    std::set<std::string> validActivations = {"RELU", "SIGMOID", "TANH", "GELU"};
    auto actIt = attrs.find("activation");
    if (actIt != attrs.end()) {
        std::string act = actIt->second.GetString();
        if (validActivations.find(act) == validActivations.end()) {
            std::cerr << "Invalid activation: " << act << std::endl;
            valid = false;
        }
    }

    return valid;
}

int main() {
    std::map<std::string, ge::AttrValue> attrs;
    attrs["group"] = ge::AttrValue::Create(1);
    attrs["pad"] = ge::AttrValue::Create(std::vector<int64_t>({0, 0, 0, 0}));
    attrs["activation"] = ge::AttrValue::Create("RELU");
    attrs["dilation"] = ge::AttrValue::Create(std::vector<int64_t>({1, 1}));

    bool ok = ValidateConvAttrs(attrs);
    std::cout << "Attribute validation: " << (ok ? "PASS" : "FAIL") << std::endl;

    // 测试非法值
    attrs["group"] = ge::AttrValue::Create(0); // 非法 group
    ok = ValidateConvAttrs(attrs);
    std::cout << "Validation after bad group: " << (ok ? "PASS" : "FAIL") << std::endl;

    return 0;
}

七、2 个关键陷阱与解决方案

陷阱一:原语定义不完整导致后端 lowering 失败

现象:自定义算子成功注册了输入输出和InferShape,但在图编译到后端时,lowering阶段报错"no kernel found for op"。

根因:算子原语定义中缺少了关键信息,例如:

  • 没有注册Tiling函数(编译期决定算子的分块参数)
  • Format约束没有声明(后端不知道输入应该是NC1HWC0还是NCHW)
  • 缺少数据类型的约束(后端无法匹配到对应的Kernel实现)

代码块 13:完整的算子原语注册(含 Tiling 和 Format 约束)

#include "register/register.h"
#include "register/op_impl_registry.h"
#include "graph/infer_shape_context.h"

using namespace ge;

// 1. 注册算子的 InferShape
IMPLEMT_COMMON_INFERFUNC(MyOpInferShape) {
    auto inputShape = op.GetInputDesc(0).GetShape();
    // 假设输出 shape 与输入相同
    op.SetOutputShape(0, inputShape);
    return GRAPH_SUCCESS;
}

// 2. 注册算子的 Tiling 函数
IMPLEMT_COMMON_TILINGFUNC(MyOpTiling) {
    // 读取输入信息以确定分块参数
    auto inputShape = tiling.GetInputDesc(0).GetShape();

    // 查询算子属性
    auto blockSize = tiling.GetAttr("block_size")->GetInt();

    // 设置 Tiling 参数
    int64_t tileSize = 256; // 根据硬件特性选择分块大小
    tiling.SetTilingKey(0); // 选择 Kernel 变体

    // 将 Tiling 参数写入二进制缓存
    std::vector<char> tilingData(sizeof(int64_t) * 2);
    // ... 填充 tilingData
    tiling.SetTilingData(tilingData.data(), tilingData.size());

    return GRAPH_SUCCESS;
}

// 3. 注册完整算子定义
IMPLEMT_COMMON_OP(MyOp, MyOpInferShape, MyOpTiling) {
    // 输入描述:声明支持的数据类型和格式
    op.Input("x")
       .ParamType(REQUIRED)
       .DataType({DT_FLOAT16, DT_FLOAT})
       .Format({FORMAT_NCHW});
    op.Input("w")
       .ParamType(REQUIRED)
       .DataType({DT_FLOAT16, DT_FLOAT})
       .Format({FORMAT_NCHW});

    // 输出描述
    op.Output("y")
      .DataType({DT_FLOAT16, DT_FLOAT});

    // 属性
    op.Attr("block_size", AttrValue::INTEGER(32));
    op.Attr("precision", AttrValue::STRING("high"));

    return op;
}

解决方案:完整的算子原语必须包含:

  1. 输入输出的DataType和Format约束
  2. InferShape函数
  3. Tiling函数(对于需要分块的算子)
  4. 所有属性的声明

陷阱二:IR 表示歧义导致等价性判断错误

现象:图优化Pass中,两个计算图结构完全相同(相同的op类型、相同的输入Shape、相同的属性),但等价性判断返回false。导致图优化跳过,生成的执行序列缺少一部分。

根因:metadef IR中某些表示存在歧义。例如:

  • DataType枚举值在跨组件序列化/反序列化时因为版本不同产生偏差
  • Format 含Component Format时(如C0/SubFormat),表面Format相同但实际C0不同
  • 动态Shape的表示方式不统一:有的位置用-1表示动态轴,有的位置用Range表示

代码块 14:等价性判断辅助函数(消除歧义)

#include "graph/tensor.h"
#include "graph/operator.h"
#include <iostream>

// 标准化的等价性比较:不考虑 C0/SubFormat 的细微差异
bool AreTensorDescsEquivalent(const ge::TensorDesc& a, const ge::TensorDesc& b) {
    // 1. 比较 Shape(将动态轴统一为 -1)
    auto shapeA = a.GetShape();
    auto shapeB = b.GetShape();

    if (shapeA.GetDimNum() != shapeB.GetDimNum()) {
        return false;
    }

    for (size_t i = 0; i < shapeA.GetDimNum(); ++i) {
        auto dimA = shapeA.GetDim(i);
        auto dimB = shapeB.GetDim(i);
        if (dimA >= 0 && dimB >= 0 && dimA != dimB) {
            return false;
        }
        // 一方为 -1 另一方为正数 → 相同(动态维度匹配)
    }

    // 2. 比较 DataType(严格相等)
    if (a.GetDataType() != b.GetDataType()) {
        return false;
    }

    // 3. 比较 Format:只比较主 Format,忽略 C0/SubFormat
    auto primaryA = GetPrimaryFormat(a.GetFormat());
    auto primaryB = GetPrimaryFormat(b.GetFormat());
    if (primaryA != primaryB) {
        return false;
    }

    return true;
}

// 更严格的等价性:考虑所有维度
bool AreTensorDescsStrictlyEqual(const ge::TensorDesc& a, const ge::TensorDesc& b) {
    auto shapeA = a.GetShape();
    auto shapeB = b.GetShape();
    if (shapeA.GetDims() != shapeB.GetDims()) return false;
    if (a.GetDataType() != b.GetDataType()) return false;
    if (a.GetFormat() != b.GetFormat()) return false;
    return true;
}

代码块 15:IR 序列化与反序列化中处理版本兼容性

#include <iostream>
#include <sstream>
#include <cstring>
#include <cstdint>

// 序列化 TensorDesc 到字符串(含版本标记)
std::string SerializeTensorDesc(const ge::TensorDesc& desc, uint32_t version = 1) {
    std::ostringstream oss;
    auto shape = desc.GetShape();

    // 写入版本号
    oss.write(reinterpret_cast<const char*>(&version), sizeof(version));

    // 写入维度数
    uint64_t dimNum = shape.GetDimNum();
    oss.write(reinterpret_cast<const char*>(&dimNum), sizeof(dimNum));

    // 写入每个维度
    for (size_t i = 0; i < dimNum; ++i) {
        int64_t dim = shape.GetDim(i);
        oss.write(reinterpret_cast<const char*>(&dim), sizeof(dim));
    }

    // 写入 datatype
    int dt = static_cast<int>(desc.GetDataType());
    oss.write(reinterpret_cast<const char*>(&dt), sizeof(dt));

    // 写入 format(使用主 format 以保证跨版本兼容)
    auto mainFormat = GetPrimaryFormat(desc.GetFormat());
    int fmt = static_cast<int>(mainFormat);
    oss.write(reinterpret_cast<const char*>(&fmt), sizeof(fmt));

    return oss.str();
}

// 反序列化 TensorDesc(版本感知)
ge::TensorDesc DeserializeTensorDesc(const std::string& data) {
    const char* ptr = data.data();

    // 读取版本号
    uint32_t version;
    std::memcpy(&version, ptr, sizeof(version));
    ptr += sizeof(version);

    // 读取维度数
    uint64_t dimNum;
    std::memcpy(&dimNum, ptr, sizeof(dimNum));
    ptr += sizeof(dimNum);

    // 读取维度
    std::vector<int64_t> dims(dimNum);
    for (size_t i = 0; i < dimNum; ++i) {
        std::memcpy(&dims[i], ptr, sizeof(dims[i]));
        ptr += sizeof(dims[i]);
    }

    // 读取 datatype
    int dt;
    std::memcpy(&dt, ptr, sizeof(dt));
    ptr += sizeof(dt);

    // 读取 format
    int fmt;
    std::memcpy(&fmt, ptr, sizeof(fmt));

    ge::TensorDesc desc;
    desc.SetShape(ge::Shape(dims));
    desc.SetDataType(static_cast<ge::DataType>(dt));
    desc.SetFormat(static_cast<ge::Format>(fmt));

    std::cout << "Deserialized from version " << version
              << ": shape=[" << (dimNum > 0 ? std::to_string(dims[0]) : "")
              << (dimNum > 1 ? "," + std::to_string(dims[1]) : "")
              << (dimNum > 2 ? "," + std::to_string(dims[2]) : "")
              << (dimNum > 3 ? "," + std::to_string(dims[3]) : "")
              << "], dtype=" << dt << ", fmt=" << fmt << std::endl;

    return desc;
}

解决方案

  1. 等价性判断时区分"表面对等"和"语义对等",对C0/SubFormat采用主Format比较。
  2. IR序列化时带上版本号,反序列化时做版本适配。
  3. 动态Shape比较时统一将-1视为"任意正数"。
  4. 关键位置使用GetPrimaryFormatGetSubFormat显式指定比较粒度。

八、实战:完整 IR 转换与算子管理脚本

代码块 16:遍历计算图并输出 IR 结构

#include "graph/graph.h"
#include "graph/op_desc.h"
#include <iostream>

void DumpGraphIR(const ge::Graph& graph) {
    auto nodes = graph.GetAllNodes();

    std::cout << "========== Graph IR Dump ==========" << std::endl;
    std::cout << "Node count: " << nodes.size() << std::endl;
    std::cout << "===================================" << std::endl;

    for (const auto& node : nodes) {
        auto opDesc = node.GetOpDesc();
        if (opDesc == nullptr) continue;

        std::cout << "[Node] name=" << opDesc->GetName()
                  << ", type=" << opDesc->GetType() << std::endl;

        // 输入信息
        auto inputs = opDesc->GetInputsDesc().GetAllInputDescs();
        for (size_t i = 0; i < inputs.size(); ++i) {
            auto& input = inputs[i];
            std::cout << "  Input[" << i << "]: "
                      << "name=" << input.GetName()
                      << ", dtype=" << static_cast<int>(input.GetDataType())
                      << ", shape=[";
            auto shape = input.GetShape();
            for (size_t d = 0; d < shape.GetDimNum(); ++d) {
                if (d > 0) std::cout << ",";
                std::cout << shape.GetDim(d);
            }
            std::cout << "]" << std::endl;
        }

        // 输出信息
        auto outputs = opDesc->GetOutputsDesc().GetAllOutputDescs();
        for (size_t i = 0; i < outputs.size(); ++i) {
            auto& output = outputs[i];
            std::cout << "  Output[" << i << "]: "
                      << "name=" << output.GetName()
                      << ", dtype=" << static_cast<int>(output.GetDataType())
                      << ", shape=[";
            auto shape = output.GetShape();
            for (size_t d = 0; d < shape.GetDimNum(); ++d) {
                if (d > 0) std::cout << ",";
                std::cout << shape.GetDim(d);
            }
            std::cout << "]" << std::endl;
        }

        // 属性信息
        auto attrs = opDesc->GetAllAttrs();
        for (const auto& pair : attrs) {
            auto attrValue = pair.second;
            std::cout << "  Attr[" << pair.first << "]: type="
                      << static_cast<int>(attrValue.GetValueType()) << std::endl;
        }
    }
}

代码块 17:IR 转换脚本 - 从框架图到 metadef IR

#!/usr/bin/env python3
"""
简化的 IR 转换示意脚本:模拟从框架图到 metadef IR 的转换。
实际 CANN 中通过 C++ 完成,此处用 Python 说明思想。
"""

class MetaTensorDesc:
    """metadef 中的 TensorDesc 对应物"""
    def __init__(self, name, shape, dtype, fmt="NCHW"):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.format = fmt

    def __repr__(self):
        return (f"TensorDesc(name={self.name}, shape={self.shape}, "
                f"dtype={self.dtype}, format={self.format})")

class MetaOpDesc:
    """metadef 中的 OpDesc 对应物"""
    def __init__(self, name, op_type):
        self.name = name
        self.type = op_type
        self.inputs = []
        self.outputs = []
        self.attrs = {}

    def add_input(self, tensor_desc):
        self.inputs.append(tensor_desc)

    def add_output(self, tensor_desc):
        self.outputs.append(tensor_desc)

    def add_attr(self, key, value):
        self.attrs[key] = value

    def __repr__(self):
        return (f"OpDesc(name={self.name}, type={self.type}, "
                f"inputs={len(self.inputs)}, outputs={len(self.outputs)}, "
                f"attrs={self.attrs})")


class MetaIRGraph:
    """metadef IR 的计算图"""
    def __init__(self, name):
        self.name = name
        self.nodes = []
        self.edges = []  # (src_node_idx, src_out_idx, dst_node_idx, dst_in_idx)

    def add_node(self, op_desc):
        self.nodes.append(op_desc)
        return len(self.nodes) - 1

    def add_edge(self, src, src_out, dst, dst_in):
        self.edges.append((src, src_out, dst, dst_in))

    def dump(self):
        print(f"=== MetaIR Graph: {self.name} ===")
        print(f"Nodes: {len(self.nodes)}, Edges: {len(self.edges)}")
        for i, node in enumerate(self.nodes):
            print(f"  Node[{i}]: {node}")
        for src, sout, dst, din in self.edges:
            print(f"  Edge: Node[{src}].out[{sout}] -> Node[{dst}].in[{din}]")


# === 模拟框架解析器 ===
class Parser:
    """将框架模型解析为 metadef IR"""

    @staticmethod
    def from_framework(model_def):
        """model_def 是简化后的框架模型表示"""
        ir = MetaIRGraph("converted_model")

        # 为每个算子创建 metadef OpDesc
        for op_spec in model_def["ops"]:
            op = MetaOpDesc(op_spec["name"], op_spec["type"])

            # 输入
            for inp in op_spec["inputs"]:
                op.add_input(MetaTensorDesc(
                    inp["name"], inp["shape"], inp["dtype"], inp.get("fmt", "NCHW")
                ))

            # 输出
            for out in op_spec["outputs"]:
                op.add_output(MetaTensorDesc(
                    out["name"], out["shape"], out["dtype"], out.get("fmt", "NCHW")
                ))

            # 属性
            if "attrs" in op_spec:
                for k, v in op_spec["attrs"].items():
                    op.add_attr(k, v)

            idx = ir.add_node(op)

            # 建立连接(简化:按顺序连接)
            if idx > 0:
                prev_output_count = len(ir.nodes[idx - 1].outputs)
                for o in range(prev_output_count):
                    ir.add_edge(idx - 1, o, idx, o)

        return ir


# === 测试转换 ===
if __name__ == "__main__":
    model = {
        "ops": [
            {
                "name": "conv_1",
                "type": "Conv2D",
                "inputs": [
                    {"name": "x", "shape": [1, 3, 224, 224], "dtype": "DT_FLOAT"},
                    {"name": "w", "shape": [64, 3, 3, 3], "dtype": "DT_FLOAT"},
                ],
                "outputs": [
                    {"name": "y", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
                ],
                "attrs": {"stride": [2, 2], "pad": [1, 1, 1, 1], "group": 1},
            },
            {
                "name": "relu_1",
                "type": "Relu",
                "inputs": [
                    {"name": "x", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
                ],
                "outputs": [
                    {"name": "y", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
                ],
                "attrs": {},
            },
            {
                "name": "pool_1",
                "type": "MaxPool",
                "inputs": [
                    {"name": "x", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
                ],
                "outputs": [
                    {"name": "y", "shape": [1, 64, 56, 56], "dtype": "DT_FLOAT"},
                ],
                "attrs": {"kernel_size": [2, 2], "strides": [2, 2]},
            },
        ]
    }

    ir = Parser.from_framework(model)
    ir.dump()

九、总结

metadef作为CANN架构中的基础组件,承担着算子原语定义和图IR统一表示的双重使命。它通过OpRegistrationData、TensorDesc、Shape、AttrValue等基础数据结构,构建了一套与前端框架无关、与底层硬件紧密适配的中间表示体系。

理解metadef的设计思路,有助于开发者:

  1. 在算子开发时:遵循metadef的原语规范注册算子,确保算子在CANN生态中可被识别和管理。
  2. 在图编译时:理解metadef IR的结构,正确查询和修改IR以获得最优的编译结果。
  3. 在跨框架适配时:利用metadef的统一原语层,一次适配即可服务多个前端框架。
Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐