CANN metadef:算子原语定义与图IR的统一表示

文章目录
一、前言
昇腾CANN(Compute Architecture for Neural Networks)是围绕昇腾NPU构建的异构计算平台,覆盖从AI框架适配到算子开发的完整编译执行链路。metadef(Meta Definition)是CANN架构中的基础组件仓,为Graph Engine(GE)和算子仓库提供共享的基础数据结构和接口抽象。它定义了算子原语的统一规范和图IR(Intermediate Representation)的表示标准,是连接前端框架、图编译器与后端硬件的关键桥梁。本文将从架构视角深度解析metadef的核心设计理念。
二、为什么需要算子原语定义层
2.1 前端框架的算子语义差异
AI前端框架(MindSpore、PyTorch、TensorFlow)各自定义了不同的算子集合和语义规范。以卷积操作为例:
| 框架 | API名称 | 数据排布默认值 | Padding语义 |
|---|---|---|---|
| PyTorch | torch.nn.Conv2d | NCHW | same/valid/显式 |
| TensorFlow | tf.nn.conv2d | NHWC | SAME/VALID/显式 |
| MindSpore | mindspore.nn.Conv2d | NCHW | same/valid/pad |
同一个数学语义的算子,在不同框架中可能有不同的名称、参数顺序、数据排布偏好和默认行为。这种差异导致:
- 语义鸿沟:框架A的Conv2d(3,64,3,padding=1)与框架B的Conv2d(3,64,3,padding=“SAME”)即使数学含义相同,其IR表示也完全不同。
- 属性不兼容:有些框架的算子属性是命名参数,有些是位置参数,属性的默认值和约束条件各不相同。
- Shape推导规则不一致:相同的输入Shape,不同框架的推导结果可能因Padding规则不同而有差异。
2.2 跨框架对齐的根本问题
要解决上述差异,传统做法是为每个框架编写独立的图编译器后端,但这会导致N个框架×M个后端的复杂度爆炸。metadef通过引入统一的算子原语定义层,将问题转化为:
N个前端框架 → metadef统一原语定义 → M个后端设备
把 N×M 的复杂度降为 N+M。这一层的核心使命是:
- 定义标准算子原语(Operator Primitives):一套与框架无关、与硬件无关的算子规范。
- 建立属性规范:统一属性命名、类型、默认值和约束条件。
- 规定数据类型系统:包括DataType枚举、Format枚举、Shape表达方式。
- 提供Shape推导契约:每个算子注册时声明其输入输出Shape的推导函数。
这正是metadef存在的根本原因——它不是某个框架的附属品,而是CANN全栈的"通用语言层"。
三、metadef 的核心设计
3.1 算子原语集合
metadef通过OpRegistrationData类提供算子注册机制,开发者可以声明算子的类型、输入输出、属性和推导函数。这套机制构成了CANN生态中所有算子管理的基础。
代码块 1:算子原语注册 - 通过 OpRegistrationData 声明一个自定义算子
#include "register/register.h"
using namespace ge;
// 注册名为 "CustomConv2d" 的算子原语
OpRegistrationData customConvReg("CustomConv2d");
// 定义输入:input, weight
customConvReg.Input("input")
.ParamType(REQUIRED)
.DataType({DT_FLOAT16, DT_FLOAT});
customConvReg.Input("weight")
.ParamType(REQUIRED)
.DataType({DT_FLOAT16, DT_FLOAT});
// 定义输出:output
customConvReg.Output("output")
.DataType({DT_FLOAT16, DT_FLOAT});
// 定义属性:stride, pad, dilation
customConvReg.Attr("stride", ListInt::CreateList({1, 1}));
customConvReg.Attr("pad", ListInt::CreateList({0, 0, 0, 0}));
customConvReg.Attr("dilation", ListInt::CreateList({1, 1}));
customConvReg.Attr("group", AttrValue::INTEGER(1));
代码块 2:查询已注册的算子原语信息
#include "register/register.h"
#include <iostream>
void QueryOpInfo(const std::string& opType) {
auto* regData = OpRegistrationData::GetInstance().GetOpRegistrationData(opType);
if (regData == nullptr) {
std::cerr << "算子 " << opType << " 未注册" << std::endl;
return;
}
std::cout << "算子类型: " << opType << std::endl;
// 遍历输入
for (const auto& input : regData->GetInputs()) {
std::cout << " 输入: " << input.GetName()
<< ", 类型数: " << input.GetDataType().size() << std::endl;
}
// 遍历输出
for (const auto& output : regData->GetOutputs()) {
std::cout << " 输出: " << output.GetName()
<< ", 类型数: " << output.GetDataType().size() << std::endl;
}
// 遍历属性
for (const auto& attr : regData->GetAttrs()) {
std::cout << " 属性: " << attr.GetName()
<< ", 默认值类型: " << attr.GetDefaultValue().GetValueType() << std::endl;
}
}
3.2 属性系统
metadef的属性系统支持多种基础数据类型,并为每个属性提供默认值和校验机制。属性定义是算子原语规范的重要组成。
代码块 3:属性类型定义与校验
#include "graph/attr_value.h"
#include "graph/utils/attr_utils.h"
#include <iostream>
void AttrValidationDemo() {
// 创建各种属性值
auto intAttr = AttrValue::Create(42);
auto floatAttr = AttrValue::Create(3.14f);
auto strAttr = AttrValue::Create("RELU");
auto listIntAttr = AttrValue::Create(std::vector<int64_t>({1, 2, 3}));
auto boolAttr = AttrValue::Create(true);
std::cout << "intAttr value: " << intAttr.GetInt() << std::endl;
std::cout << "floatAttr value: " << floatAttr.GetFloat() << std::endl;
std::cout << "strAttr value: " << strAttr.GetString() << std::endl;
// 类型校验:尝试从 int 读取 float 会失败
if (intAttr.GetValueType() == AttrValue::VT_FLOAT) {
std::cout << "int is float: " << intAttr.GetFloat() << std::endl;
} else {
std::cout << "intAttr is not float type (VT: "
<< static_cast<int>(intAttr.GetValueType()) << ")" << std::endl;
}
}
3.3 Shape 函数
每个算子原语注册时可以绑定InferShape函数,用于在编译期根据输入Shape推导输出Shape。这是图编译中必不可少的一环。
代码块 4:InferShape 函数注册
#include "graph/operator_reg.h"
#include "graph/infer_shape_context.h"
IMPLEMT_COMMON_INFERFUNC(CustomConv2dInferShape) {
// 获取输入 shape
auto inputShape = op.GetInputDescByName("input").GetShape();
auto weightShape = op.GetInputDescByName("weight").GetShape();
// 解析属性
auto stride = op.GetAttr("stride")->GetListInt();
auto pad = op.GetAttr("pad")->GetListInt();
// NCHW 格式:input = [N, C, H, W], weight = [K, C, R, S]
auto N = inputShape.GetDim(0);
auto H = inputShape.GetDim(2);
auto W = inputShape.GetDim(3);
auto K = weightShape.GetDim(0);
auto R = weightShape.GetDim(2);
auto S = weightShape.GetDim(3);
// 计算输出 shape
auto outH = (H + 2 * pad[0] - R) / stride[0] + 1;
auto outW = (W + 2 * pad[1] - S) / stride[1] + 1;
// 设置输出 shape
ge::Shape outShape({N, K, outH, outW});
op.GetOutputDescByName("output").SetShape(outShape);
return GRAPH_SUCCESS;
}
// 将 InferShape 函数注册到算子原语
MAKE_CUSTOM_OP(CustomConv2d, CustomConv2dInferShape)
代码块 5:Shape 查询与操作
#include "graph/tensor.h"
#include "graph/shape.h"
#include <iostream>
void ShapeManipulation() {
// 构造一个 4D Shape [N, C, H, W]
ge::Shape shape({1, 3, 224, 224});
std::cout << "Shape dims: " << shape.GetDimNum() << std::endl;
std::cout << "Dim 0 (N): " << shape.GetDim(0) << std::endl;
std::cout << "Dim 1 (C): " << shape.GetDim(1) << std::endl;
std::cout << "Dim 2 (H): " << shape.GetDim(2) << std::endl;
std::cout << "Dim 3 (W): " << shape.GetDim(3) << std::endl;
// 获取 shape 的连续表示
auto dims = shape.GetDims();
for (size_t i = 0; i < dims.size(); ++i) {
std::cout << " dim[" << i << "] = " << dims[i] << std::endl;
}
// 获取总元素数
int64_t total = 1;
for (auto d : dims) {
if (d > 0) total *= d;
}
std::cout << "Total elements: " << total << std::endl;
}
3.4 数据类型规范
metadef统一管理数据类型(DataType)和数据排布格式(Format),这是IR表示的核心基础。
代码块 6:数据类型枚举与格式查询
#include "graph/types.h"
#include "graph/utils/type_utils.h"
#include <iostream>
void DataTypeDemo() {
// 查询数据类型的字符串表示
std::cout << "DT_FLOAT -> " << TypeUtils::DataTypeToSerialString(DT_FLOAT) << std::endl;
std::cout << "DT_FLOAT16 -> " << TypeUtils::DataTypeToSerialString(DT_FLOAT16) << std::endl;
std::cout << "DT_INT8 -> " << TypeUtils::DataTypeToSerialString(DT_INT8) << std::endl;
std::cout << "DT_BF16 -> " << TypeUtils::DataTypeToSerialString(DT_BF16) << std::endl;
// 查询数据格式的字符串表示
std::cout << "FORMAT_NCHW -> " << TypeUtils::FormatToSerialString(FORMAT_NCHW) << std::endl;
std::cout << "FORMAT_NHWC -> " << TypeUtils::FormatToSerialString(FORMAT_NHWC) << std::endl;
std::cout << "FORMAT_ND -> " << TypeUtils::FormatToSerialString(FORMAT_ND) << std::endl;
std::cout << "FORMAT_NC1HWC0 -> " << TypeUtils::FormatToSerialString(FORMAT_NC1HWC0) << std::endl;
// 根据字符串反查枚举
ge::DataType dt;
TypeUtils::SerialStringToDataType("DT_FLOAT", dt);
std::cout << "Parsed DT_FLOAT -> enum value: " << static_cast<int>(dt) << std::endl;
}
void SizeInBytesDemo() {
// 计算指定数据类型和元素数占用的内存大小
int64_t totalBytes = GetSizeInBytes(1024, DT_FLOAT);
std::cout << "1024 floats occupy " << totalBytes << " bytes" << std::endl;
auto size = GetSizeByDataType(DT_FLOAT16);
std::cout << "float16 occupies " << size << " bytes per element" << std::endl;
}
四、图 IR 统一表示的实现
4.1 三层转换架构
CANN的图编译流程可以抽象为三个层级:
计算图 (前端框架感知)
│
▼
Parser 层: 将框架计算的图解析为 metadef IR
│
▼
metadef IR (统一中间表示,算子以原语形式表达)
│
▼
Graph Compiler: 消费 metadef IR,进行图优化、算子选择、内存安排
│
▼
多后端 Lowering: 设备无关优化 → 设备相关优化 → 可执行指令
第一层:Parser(框架解析)。CANN通过AscendCL等接口接收来自前端框架的模型,Parser层负责将框架特有的图表示(如MindSpore的ANF图、PyTorch的TorchScript、TensorFlow的GraphDef)解析为metadef IR。这一层完成了框架算子到metadef原语的映射。
第二层:metadef IR(统一中间表示)。这是核心抽象层,所有框架的算子统一用metadef定义的算子原语、TensorDesc、Shape、AttrValue等数据结构表达。IR不再包含任何框架特定的语义,只保留与计算相关的本质信息。
第三层:Lowering(多后端分发)。Graph Compiler消费统一的metadef IR,进行算子选择、内存优化、图切分等操作,最终lower到具体的硬件后端执行。
代码块 7:TensorDesc 构建 - 构造 metadef IR 的核心数据结构
#include "graph/tensor.h"
#include "graph/types.h"
#include <iostream>
void BuildTensorDesc() {
// 创建一个 TensorDesc
ge::TensorDesc desc;
// 设置 shape
ge::Shape shape({1, 3, 224, 224});
desc.SetShape(shape);
// 设置数据类型
desc.SetDataType(ge::DT_FLOAT);
// 设置数据排布格式
desc.SetFormat(ge::FORMAT_NCHW);
// 设置名称
desc.SetName("input_tensor");
// 查询
std::cout << "Tensor name: " << desc.GetName() << std::endl;
std::cout << "DataType: " << static_cast<int>(desc.GetDataType()) << std::endl;
std::cout << "Format: " << static_cast<int>(desc.GetFormat()) << std::endl;
std::cout << "Shape dims: " << desc.GetShape().GetDimNum() << std::endl;
// 或者构建时直接初始化
ge::TensorDesc initDesc(
ge::Shape({1024, 512}),
ge::FORMAT_ND,
ge::DT_FLOAT16
);
std::cout << "InitDesc format: "
<< static_cast<int>(initDesc.GetFormat()) << std::endl;
}
代码块 8:构建 metadef 计算图(简化示意)
#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include <iostream>
void BuildComputeGraph() {
// 创建计算图
ge::Graph graph("sample_graph");
// 创建算子描述
auto addOpDesc = std::make_shared<ge::OpDesc>("add_1", "Add");
auto mulOpDesc = std::make_shared<ge::OpDesc>("mul_1", "Mul");
// 为 Add 算子设置输入输出 TensorDesc
addOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
addOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
addOpDesc->AddOutputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
// 为 Mul 算子设置输入输出 TensorDesc
mulOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
mulOpDesc->AddInputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
mulOpDesc->AddOutputDesc(ge::TensorDesc(ge::Shape({4, 4}), ge::FORMAT_ND, ge::DT_FLOAT));
// 创建算子节点
auto addNode = graph.AddNode(addOpDesc);
auto mulNode = graph.AddNode(mulOpDesc);
// 添加边:add 的输出连接到 mul 的第一个输入
graph.AddEdge(addNode, 0, mulNode, 0);
std::cout << "计算图创建完成,包含 "
<< graph.GetAllNodes().size() << " 个节点" << std::endl;
}
4.2 IR 中的数据格式转换
在metadef IR层面,数据排布格式的转换是一个关键能力。昇腾NPU内部使用5D格式(NC1HWC0),而前端框架通常使用4D格式(NCHW/NHWC)。metadef通过TypeUtils系列函数管理这种格式转换。
代码块 9:Format 转换查询
#include "graph/types.h"
#include <iostream>
#include <cstring>
void FormatConversionDemo() {
// CANN 支持多种数据格式,通过 C0/SubFormat 机制管理
ge::Format fmt = FORMAT_NC1HWC0;
// 5D 格式 NC1HWC0 的物理含义:
// N: batch, C1: C // 16, H: height, W: width, C0: 16 (对齐单位)
int64_t N = 1, C = 32, H = 56, W = 56, C0 = 16;
int64_t C1 = (C + C0 - 1) / C0;
int64_t nchwSize = N * C * H * W;
int64_t nc1hwc0Size = N * C1 * H * W * C0;
std::cout << "NCHW storage: " << nchwSize << " elements" << std::endl;
std::cout << "NC1HWC0 storage: " << nc1hwc0Size << " elements (aligned)" << std::endl;
std::cout << "Format overhead: " << (nc1hwc0Size - nchwSize) << " elements" << std::endl;
}
五、与 ONNX / MLIR 的对比
5.1 metadef vs ONNX
| 对比维度 | ONNX | metadef |
|---|---|---|
| 定位 | 跨框架模型交换格式 | CANN 架构内部中间表示 |
| 算子定义 | 固定算子集(opsets版本化) | 可扩展算子原语注册机制 |
| 与硬件关系 | 框架中立、硬件中立 | 专为昇腾NPU优化 |
| Shape推导 | 可选、依赖框架 | 强制注册InferShape函数 |
| 属性系统 | 协议缓冲区(protobuf)序列化 | C++原生类型+AttrValue |
| 数据格式 | NCHW/NHWC 等标准格式 | 含NC1HWC0等昇腾特有格式 |
metadef比ONNX更贴近硬件的原因是:ONNX的设计目标是跨平台交换,因此其算子定义是最小公共子集,不包含特定硬件的信息。而metadef的设计目标是CANN内部的编译优化,因此:
- 支持昇腾特有格式:NC1HWC0、FRACTAL_NZ等5D格式直接内建在IR中。
- 支持硬件特性表达:算子原语定义中可以直接声明对Ascend C Kernel的需求。
- 紧耦合的推导链:InferShape、InferDataType、InferFormat形成一个完整的编译期推导链。
5.2 metadef vs MLIR
| 对比维度 | MLIR | metadef |
|---|---|---|
| 定位 | 可扩展多级IR编译框架 | 面向图编译的原语定义层 |
| 框架体重 | 重型,支持Dialect自定义 | 轻量,聚焦算子定义和Graph |
| 学习曲线 | 陡峭(Dialect、Pass、Pattern) | 平缓(C++类+注册机制) |
| 适用场景 | 复杂编译场景、多硬件 | CANN图编译和闭管场景 |
| 工程复杂度 | 高,需Dialect设计 | 低,直接扩展OpRegistrationData |
metadef比MLIR更轻量的原因在于:MLIR推崇通过Dialect机制表达所有计算语义,但其代价是框架本身的复杂度(每个Dialect需要定义Type、Attribute、Operation、Pattern Rewrite等)。metadef选择了一个更务实的方案:将算子原语定义集中在OpRegistrationData一套API中,Shape推导通过注册InferShape函数完成,不需要额外的模式匹配框架。
5.3 综合定位
metadef在技术路线上与ONNX和MLIR既有交集又有差异:
- 与ONNX的交集:都定义了算子规格、数据类型和属性系统。差异在于ONNX面向交换,metadef面向编译。
- 与MLIR的交集:都提供中间表示。差异在于MLIR的Dialect框架是一种可复用的IR设计范式,metadef是一种面向CANN场景的专用IR。
一句话总结:metadef比ONNX更贴近硬件、比MLIR更轻量,是CANN场景下的最优选择。
六、metadef 在 CANN 编译流水线中的位置
6.1 编译流水线全景
前端框架模型
│
▼
┌─────────────────────────────────────┐
│ AscendCL / Framework Parser │
│ (将框架算子映射为metadef算子原语) │
└──────────┬──────────────────────────┘
│ metadef IR (由 TensorDesc, OpDesc, Graph 构成)
▼
┌─────────────────────────────────────┐
│ Graph Compiler (GE) │
│ - 图优化:常量折叠、死代码消除 │
│ - 算子选择:选择最优Ascend C Kernel │
│ - 内存安排:地址分配、复用 │
│ - 图切分:NPU/CPU协同 │
└──────────┬──────────────────────────┘
│ 优化后的 metadef IR
▼
┌─────────────────────────────────────┐
│ Lowering: 多后端分发 │
│ → 设备相关优化 │
│ → 生成可执行指令序列 │
└─────────────────────────────────────┘
Parser消费前端框架的图,输出metadef IR。Graph Compiler(GE)消费metadef IR,进行图优化和算子选择。metadef在其中发挥了契约层的作用——Parser按照metadef的算子原语规范产出IR,GE按照同样的规范消费IR。
6.2 从 IR 查询到执行上下文
在编译流水线中,metadef的InferShape上下文(InferShapeContext)和Tiling上下文(TilingContext)分别在编译期和算机选择阶段被使用。
代码块 10:从 IR 算子中查询属性并进行 Shape 推导
#include "graph/infer_shape_context.h"
#include <iostream>
ge::graphStatus InferShapeForPooling(InferShapeContext& ctx) {
// 获取输入 tensor 描述
auto inputDesc = ctx.GetInputDesc(0);
auto inputShape = inputDesc.GetShape();
if (inputShape.GetDimNum() != 4) {
std::cerr << "Expected 4D input, got " << inputShape.GetDimNum() << "D" << std::endl;
return GRAPH_FAILED;
}
// 解析 pooling 参数
auto kernelSize = ctx.GetAttr("kernel_size")->GetListInt();
auto stride = ctx.GetAttr("strides")->GetListInt();
auto padMode = ctx.GetAttr("padding")->GetString();
auto N = inputShape.GetDim(0);
auto C = inputShape.GetDim(1);
auto H = inputShape.GetDim(2);
auto W = inputShape.GetDim(3);
int64_t outH, outW;
if (padMode == "SAME") {
outH = (H + stride[0] - 1) / stride[0];
outW = (W + stride[1] - 1) / stride[1];
} else {
outH = (H - kernelSize[0]) / stride[0] + 1;
outW = (W - kernelSize[1]) / stride[1] + 1;
}
ge::Shape outputShape({N, C, outH, outW});
ctx.SetOutputShape(0, outputShape);
return GRAPH_SUCCESS;
}
代码块 11:利用 Range 和 StorageShape 进行动态 Shape 推导
#include "exe_graph/runtime/range.h"
#include "exe_graph/runtime/storage_shape.h"
#include <iostream>
void DynamicShapeDemo() {
// Range 描述一个动态范围 [min, max]
gert::Range batchRange(1, 64);
// StorageShape 同时存储 origin shape 和 storage shape
gert::StorageShape storageShape(
ge::Shape({-1, 3, 224, 224}), // origin shape(含动态维度 -1)
ge::Shape({1, 3, 224, 224}) // storage shape(运行时实际值)
);
// 设置 CompileTimeTensorDesc
gert::CompileTimeTensorDesc compileDesc;
compileDesc.dataType = ge::DT_FLOAT;
compileDesc.format = ge::FORMAT_NCHW;
std::cout << "Compile-time tensor: dtype="
<< static_cast<int>(compileDesc.dataType)
<< ", format=" << static_cast<int>(compileDesc.format) << std::endl;
}
代码块 12:属性校验 - 检查属性的合法值
#include "graph/attr_value.h"
#include <iostream>
#include <set>
bool ValidateConvAttrs(const std::map<std::string, ge::AttrValue>& attrs) {
bool valid = true;
// 校验 group 参数:必须 >= 1
auto groupIt = attrs.find("group");
if (groupIt != attrs.end()) {
int64_t group = groupIt->second.GetInt();
if (group < 1) {
std::cerr << "group must >= 1, got " << group << std::endl;
valid = false;
}
}
// 校验 padding 参数:必须是 4 个非负整数 [top, bottom, left, right]
auto padIt = attrs.find("pad");
if (padIt != attrs.end()) {
auto pad = padIt->second.GetListInt();
if (pad.size() != 4) {
std::cerr << "pad must have 4 values, got " << pad.size() << std::endl;
valid = false;
}
for (auto v : pad) {
if (v < 0) {
std::cerr << "pad values must be non-negative, got " << v << std::endl;
valid = false;
}
}
}
// 校验 activation 参数:必须是有效激活函数名
std::set<std::string> validActivations = {"RELU", "SIGMOID", "TANH", "GELU"};
auto actIt = attrs.find("activation");
if (actIt != attrs.end()) {
std::string act = actIt->second.GetString();
if (validActivations.find(act) == validActivations.end()) {
std::cerr << "Invalid activation: " << act << std::endl;
valid = false;
}
}
return valid;
}
int main() {
std::map<std::string, ge::AttrValue> attrs;
attrs["group"] = ge::AttrValue::Create(1);
attrs["pad"] = ge::AttrValue::Create(std::vector<int64_t>({0, 0, 0, 0}));
attrs["activation"] = ge::AttrValue::Create("RELU");
attrs["dilation"] = ge::AttrValue::Create(std::vector<int64_t>({1, 1}));
bool ok = ValidateConvAttrs(attrs);
std::cout << "Attribute validation: " << (ok ? "PASS" : "FAIL") << std::endl;
// 测试非法值
attrs["group"] = ge::AttrValue::Create(0); // 非法 group
ok = ValidateConvAttrs(attrs);
std::cout << "Validation after bad group: " << (ok ? "PASS" : "FAIL") << std::endl;
return 0;
}
七、2 个关键陷阱与解决方案
陷阱一:原语定义不完整导致后端 lowering 失败
现象:自定义算子成功注册了输入输出和InferShape,但在图编译到后端时,lowering阶段报错"no kernel found for op"。
根因:算子原语定义中缺少了关键信息,例如:
- 没有注册Tiling函数(编译期决定算子的分块参数)
- Format约束没有声明(后端不知道输入应该是NC1HWC0还是NCHW)
- 缺少数据类型的约束(后端无法匹配到对应的Kernel实现)
代码块 13:完整的算子原语注册(含 Tiling 和 Format 约束)
#include "register/register.h"
#include "register/op_impl_registry.h"
#include "graph/infer_shape_context.h"
using namespace ge;
// 1. 注册算子的 InferShape
IMPLEMT_COMMON_INFERFUNC(MyOpInferShape) {
auto inputShape = op.GetInputDesc(0).GetShape();
// 假设输出 shape 与输入相同
op.SetOutputShape(0, inputShape);
return GRAPH_SUCCESS;
}
// 2. 注册算子的 Tiling 函数
IMPLEMT_COMMON_TILINGFUNC(MyOpTiling) {
// 读取输入信息以确定分块参数
auto inputShape = tiling.GetInputDesc(0).GetShape();
// 查询算子属性
auto blockSize = tiling.GetAttr("block_size")->GetInt();
// 设置 Tiling 参数
int64_t tileSize = 256; // 根据硬件特性选择分块大小
tiling.SetTilingKey(0); // 选择 Kernel 变体
// 将 Tiling 参数写入二进制缓存
std::vector<char> tilingData(sizeof(int64_t) * 2);
// ... 填充 tilingData
tiling.SetTilingData(tilingData.data(), tilingData.size());
return GRAPH_SUCCESS;
}
// 3. 注册完整算子定义
IMPLEMT_COMMON_OP(MyOp, MyOpInferShape, MyOpTiling) {
// 输入描述:声明支持的数据类型和格式
op.Input("x")
.ParamType(REQUIRED)
.DataType({DT_FLOAT16, DT_FLOAT})
.Format({FORMAT_NCHW});
op.Input("w")
.ParamType(REQUIRED)
.DataType({DT_FLOAT16, DT_FLOAT})
.Format({FORMAT_NCHW});
// 输出描述
op.Output("y")
.DataType({DT_FLOAT16, DT_FLOAT});
// 属性
op.Attr("block_size", AttrValue::INTEGER(32));
op.Attr("precision", AttrValue::STRING("high"));
return op;
}
解决方案:完整的算子原语必须包含:
- 输入输出的DataType和Format约束
- InferShape函数
- Tiling函数(对于需要分块的算子)
- 所有属性的声明
陷阱二:IR 表示歧义导致等价性判断错误
现象:图优化Pass中,两个计算图结构完全相同(相同的op类型、相同的输入Shape、相同的属性),但等价性判断返回false。导致图优化跳过,生成的执行序列缺少一部分。
根因:metadef IR中某些表示存在歧义。例如:
- DataType枚举值在跨组件序列化/反序列化时因为版本不同产生偏差
- Format 含Component Format时(如C0/SubFormat),表面Format相同但实际C0不同
- 动态Shape的表示方式不统一:有的位置用-1表示动态轴,有的位置用Range表示
代码块 14:等价性判断辅助函数(消除歧义)
#include "graph/tensor.h"
#include "graph/operator.h"
#include <iostream>
// 标准化的等价性比较:不考虑 C0/SubFormat 的细微差异
bool AreTensorDescsEquivalent(const ge::TensorDesc& a, const ge::TensorDesc& b) {
// 1. 比较 Shape(将动态轴统一为 -1)
auto shapeA = a.GetShape();
auto shapeB = b.GetShape();
if (shapeA.GetDimNum() != shapeB.GetDimNum()) {
return false;
}
for (size_t i = 0; i < shapeA.GetDimNum(); ++i) {
auto dimA = shapeA.GetDim(i);
auto dimB = shapeB.GetDim(i);
if (dimA >= 0 && dimB >= 0 && dimA != dimB) {
return false;
}
// 一方为 -1 另一方为正数 → 相同(动态维度匹配)
}
// 2. 比较 DataType(严格相等)
if (a.GetDataType() != b.GetDataType()) {
return false;
}
// 3. 比较 Format:只比较主 Format,忽略 C0/SubFormat
auto primaryA = GetPrimaryFormat(a.GetFormat());
auto primaryB = GetPrimaryFormat(b.GetFormat());
if (primaryA != primaryB) {
return false;
}
return true;
}
// 更严格的等价性:考虑所有维度
bool AreTensorDescsStrictlyEqual(const ge::TensorDesc& a, const ge::TensorDesc& b) {
auto shapeA = a.GetShape();
auto shapeB = b.GetShape();
if (shapeA.GetDims() != shapeB.GetDims()) return false;
if (a.GetDataType() != b.GetDataType()) return false;
if (a.GetFormat() != b.GetFormat()) return false;
return true;
}
代码块 15:IR 序列化与反序列化中处理版本兼容性
#include <iostream>
#include <sstream>
#include <cstring>
#include <cstdint>
// 序列化 TensorDesc 到字符串(含版本标记)
std::string SerializeTensorDesc(const ge::TensorDesc& desc, uint32_t version = 1) {
std::ostringstream oss;
auto shape = desc.GetShape();
// 写入版本号
oss.write(reinterpret_cast<const char*>(&version), sizeof(version));
// 写入维度数
uint64_t dimNum = shape.GetDimNum();
oss.write(reinterpret_cast<const char*>(&dimNum), sizeof(dimNum));
// 写入每个维度
for (size_t i = 0; i < dimNum; ++i) {
int64_t dim = shape.GetDim(i);
oss.write(reinterpret_cast<const char*>(&dim), sizeof(dim));
}
// 写入 datatype
int dt = static_cast<int>(desc.GetDataType());
oss.write(reinterpret_cast<const char*>(&dt), sizeof(dt));
// 写入 format(使用主 format 以保证跨版本兼容)
auto mainFormat = GetPrimaryFormat(desc.GetFormat());
int fmt = static_cast<int>(mainFormat);
oss.write(reinterpret_cast<const char*>(&fmt), sizeof(fmt));
return oss.str();
}
// 反序列化 TensorDesc(版本感知)
ge::TensorDesc DeserializeTensorDesc(const std::string& data) {
const char* ptr = data.data();
// 读取版本号
uint32_t version;
std::memcpy(&version, ptr, sizeof(version));
ptr += sizeof(version);
// 读取维度数
uint64_t dimNum;
std::memcpy(&dimNum, ptr, sizeof(dimNum));
ptr += sizeof(dimNum);
// 读取维度
std::vector<int64_t> dims(dimNum);
for (size_t i = 0; i < dimNum; ++i) {
std::memcpy(&dims[i], ptr, sizeof(dims[i]));
ptr += sizeof(dims[i]);
}
// 读取 datatype
int dt;
std::memcpy(&dt, ptr, sizeof(dt));
ptr += sizeof(dt);
// 读取 format
int fmt;
std::memcpy(&fmt, ptr, sizeof(fmt));
ge::TensorDesc desc;
desc.SetShape(ge::Shape(dims));
desc.SetDataType(static_cast<ge::DataType>(dt));
desc.SetFormat(static_cast<ge::Format>(fmt));
std::cout << "Deserialized from version " << version
<< ": shape=[" << (dimNum > 0 ? std::to_string(dims[0]) : "")
<< (dimNum > 1 ? "," + std::to_string(dims[1]) : "")
<< (dimNum > 2 ? "," + std::to_string(dims[2]) : "")
<< (dimNum > 3 ? "," + std::to_string(dims[3]) : "")
<< "], dtype=" << dt << ", fmt=" << fmt << std::endl;
return desc;
}
解决方案:
- 等价性判断时区分"表面对等"和"语义对等",对C0/SubFormat采用主Format比较。
- IR序列化时带上版本号,反序列化时做版本适配。
- 动态Shape比较时统一将-1视为"任意正数"。
- 关键位置使用
GetPrimaryFormat和GetSubFormat显式指定比较粒度。
八、实战:完整 IR 转换与算子管理脚本
代码块 16:遍历计算图并输出 IR 结构
#include "graph/graph.h"
#include "graph/op_desc.h"
#include <iostream>
void DumpGraphIR(const ge::Graph& graph) {
auto nodes = graph.GetAllNodes();
std::cout << "========== Graph IR Dump ==========" << std::endl;
std::cout << "Node count: " << nodes.size() << std::endl;
std::cout << "===================================" << std::endl;
for (const auto& node : nodes) {
auto opDesc = node.GetOpDesc();
if (opDesc == nullptr) continue;
std::cout << "[Node] name=" << opDesc->GetName()
<< ", type=" << opDesc->GetType() << std::endl;
// 输入信息
auto inputs = opDesc->GetInputsDesc().GetAllInputDescs();
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
std::cout << " Input[" << i << "]: "
<< "name=" << input.GetName()
<< ", dtype=" << static_cast<int>(input.GetDataType())
<< ", shape=[";
auto shape = input.GetShape();
for (size_t d = 0; d < shape.GetDimNum(); ++d) {
if (d > 0) std::cout << ",";
std::cout << shape.GetDim(d);
}
std::cout << "]" << std::endl;
}
// 输出信息
auto outputs = opDesc->GetOutputsDesc().GetAllOutputDescs();
for (size_t i = 0; i < outputs.size(); ++i) {
auto& output = outputs[i];
std::cout << " Output[" << i << "]: "
<< "name=" << output.GetName()
<< ", dtype=" << static_cast<int>(output.GetDataType())
<< ", shape=[";
auto shape = output.GetShape();
for (size_t d = 0; d < shape.GetDimNum(); ++d) {
if (d > 0) std::cout << ",";
std::cout << shape.GetDim(d);
}
std::cout << "]" << std::endl;
}
// 属性信息
auto attrs = opDesc->GetAllAttrs();
for (const auto& pair : attrs) {
auto attrValue = pair.second;
std::cout << " Attr[" << pair.first << "]: type="
<< static_cast<int>(attrValue.GetValueType()) << std::endl;
}
}
}
代码块 17:IR 转换脚本 - 从框架图到 metadef IR
#!/usr/bin/env python3
"""
简化的 IR 转换示意脚本:模拟从框架图到 metadef IR 的转换。
实际 CANN 中通过 C++ 完成,此处用 Python 说明思想。
"""
class MetaTensorDesc:
"""metadef 中的 TensorDesc 对应物"""
def __init__(self, name, shape, dtype, fmt="NCHW"):
self.name = name
self.shape = shape
self.dtype = dtype
self.format = fmt
def __repr__(self):
return (f"TensorDesc(name={self.name}, shape={self.shape}, "
f"dtype={self.dtype}, format={self.format})")
class MetaOpDesc:
"""metadef 中的 OpDesc 对应物"""
def __init__(self, name, op_type):
self.name = name
self.type = op_type
self.inputs = []
self.outputs = []
self.attrs = {}
def add_input(self, tensor_desc):
self.inputs.append(tensor_desc)
def add_output(self, tensor_desc):
self.outputs.append(tensor_desc)
def add_attr(self, key, value):
self.attrs[key] = value
def __repr__(self):
return (f"OpDesc(name={self.name}, type={self.type}, "
f"inputs={len(self.inputs)}, outputs={len(self.outputs)}, "
f"attrs={self.attrs})")
class MetaIRGraph:
"""metadef IR 的计算图"""
def __init__(self, name):
self.name = name
self.nodes = []
self.edges = [] # (src_node_idx, src_out_idx, dst_node_idx, dst_in_idx)
def add_node(self, op_desc):
self.nodes.append(op_desc)
return len(self.nodes) - 1
def add_edge(self, src, src_out, dst, dst_in):
self.edges.append((src, src_out, dst, dst_in))
def dump(self):
print(f"=== MetaIR Graph: {self.name} ===")
print(f"Nodes: {len(self.nodes)}, Edges: {len(self.edges)}")
for i, node in enumerate(self.nodes):
print(f" Node[{i}]: {node}")
for src, sout, dst, din in self.edges:
print(f" Edge: Node[{src}].out[{sout}] -> Node[{dst}].in[{din}]")
# === 模拟框架解析器 ===
class Parser:
"""将框架模型解析为 metadef IR"""
@staticmethod
def from_framework(model_def):
"""model_def 是简化后的框架模型表示"""
ir = MetaIRGraph("converted_model")
# 为每个算子创建 metadef OpDesc
for op_spec in model_def["ops"]:
op = MetaOpDesc(op_spec["name"], op_spec["type"])
# 输入
for inp in op_spec["inputs"]:
op.add_input(MetaTensorDesc(
inp["name"], inp["shape"], inp["dtype"], inp.get("fmt", "NCHW")
))
# 输出
for out in op_spec["outputs"]:
op.add_output(MetaTensorDesc(
out["name"], out["shape"], out["dtype"], out.get("fmt", "NCHW")
))
# 属性
if "attrs" in op_spec:
for k, v in op_spec["attrs"].items():
op.add_attr(k, v)
idx = ir.add_node(op)
# 建立连接(简化:按顺序连接)
if idx > 0:
prev_output_count = len(ir.nodes[idx - 1].outputs)
for o in range(prev_output_count):
ir.add_edge(idx - 1, o, idx, o)
return ir
# === 测试转换 ===
if __name__ == "__main__":
model = {
"ops": [
{
"name": "conv_1",
"type": "Conv2D",
"inputs": [
{"name": "x", "shape": [1, 3, 224, 224], "dtype": "DT_FLOAT"},
{"name": "w", "shape": [64, 3, 3, 3], "dtype": "DT_FLOAT"},
],
"outputs": [
{"name": "y", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
],
"attrs": {"stride": [2, 2], "pad": [1, 1, 1, 1], "group": 1},
},
{
"name": "relu_1",
"type": "Relu",
"inputs": [
{"name": "x", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
],
"outputs": [
{"name": "y", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
],
"attrs": {},
},
{
"name": "pool_1",
"type": "MaxPool",
"inputs": [
{"name": "x", "shape": [1, 64, 112, 112], "dtype": "DT_FLOAT"},
],
"outputs": [
{"name": "y", "shape": [1, 64, 56, 56], "dtype": "DT_FLOAT"},
],
"attrs": {"kernel_size": [2, 2], "strides": [2, 2]},
},
]
}
ir = Parser.from_framework(model)
ir.dump()
九、总结
metadef作为CANN架构中的基础组件,承担着算子原语定义和图IR统一表示的双重使命。它通过OpRegistrationData、TensorDesc、Shape、AttrValue等基础数据结构,构建了一套与前端框架无关、与底层硬件紧密适配的中间表示体系。
理解metadef的设计思路,有助于开发者:
- 在算子开发时:遵循metadef的原语规范注册算子,确保算子在CANN生态中可被识别和管理。
- 在图编译时:理解metadef IR的结构,正确查询和修改IR以获得最优的编译结果。
- 在跨框架适配时:利用metadef的统一原语层,一次适配即可服务多个前端框架。
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐
所有评论(0)