训练营简介 2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。

报名链接:https://www.hiascend.com/developer/activities/cann20252#cann-camp-2502-intro

前言

在算子开发中,我们经常听到“静态 Shape”和“动态 Shape”之争。

  • 静态 Shape (Static):在编译期就把 Shape 写死(如 constexpr int N = 1024)。

    • 优点:编译器能做极致优化(循环展开、常量折叠)。

    • 缺点:输入变了,算子就废了,必须重编。

  • 动态 Shape (Dynamic):在运行期读取 Shape。

    • 优点:灵活,一个二进制吃遍天下。

    • 缺点:开发难度稍大,需要配合 Host 侧的动态 Tiling。

在 Ascend C 的工业级开发中,二进制泛化是标配。这意味着我们的 Kernel 代码中不能出现硬编码的数字,一切都要依赖从 Host 传来的 TilingData

本期文章将带你打通 Host 与 Device 的动态任督二脉,实现真正的全场景适配。

一、 核心图解:液态金属架构

如果说静态算子是坚硬的“冰块”,只能放入特定形状的容器;那么泛化算子就是“液态金属”,可以随容器(Input Shape)的形状自由流动变化。

二、 泛化开发的“三驾马车”

要实现二进制泛化,我们需要协同三个模块的工作:

  1. InferShape (Host):告诉框架,输出的 Shape 长什么样。

  2. Tiling (Host):在运行时根据具体的 Shape,计算出切分参数。

  3. Kernel (Device):盲目执行,只认 Tiling 传过来的参数。

2.1 第一驾:InferShape 推导

当输入 Shape 变化时,输出 Shape 通常也会变(比如 Element-wise 操作)。我们需要在 op_host/add_custom.cpp 中注册推导逻辑。

// 注册 InferShape 函数
// 这里直接复用 CANN 提供的通用推导函数,它会自动推导 Output = Input
// 如果你的算子有特殊的 Shape 变化(如 Reduce, MatMul),需要自定义
this->SetInferShape(ge::InferShape);

2.2 第二驾:动态 Tiling 计算 (大脑)

这是泛化的核心。Tiling 函数会在每次算子执行前被 Runtime 调用。此时,我们能拿到真实的输入 Shape。

// op_host/add_custom_tiling.cpp

static ge::graphStatus TilingFunc(gert::TilingContext* context) {
    // 1. [关键] 从 Context 获取运行时真实的 Input Shape
    const gert::StorageShape* x_shape = context->GetInputShape(0);
    int64_t total_elements = x_shape->GetStorageShape().GetShapeSize();

    // 2. 获取硬件信息 (UB 大小, 核数)
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    uint32_t core_num = ascendcPlatform.GetCoreNumAic();
    uint32_t ub_size = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size);

    // 3. 执行动态切分逻辑
    // (逻辑参考第四期 Tiling 教程,此处省略细节)
    // 重点是:所有的计算都基于 total_elements 和 core_num 变量,而不是常量
    uint32_t tile_num = ...;
    uint32_t tile_len = ...;

    // 4. 填入 TilingData 结构体
    AddCustomTilingData tiling;
    tiling.set_totalLength(total_elements);
    tiling.set_tileNum(tile_num);
    tiling.set_tileLength(tile_len);
    
    // ... 序列化 ...
    return ge::GRAPH_SUCCESS;
}

2.3 第三驾:Kernel 执行 (肌肉)

Kernel 代码不需要做任何逻辑修改,因为它天生就是读取 TilingData 的。

// op_kernel/add_custom.cpp

extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, ..., GM_ADDR tiling) {
    // 1. 反序列化拿到 Host 算好的参数
    GET_TILING_DATA(tiling_data, tiling);
    
    // 2. 直接使用参数,完全不关心这些数字是 1024 还是 2048
    // 只要 Host 算对了,Device 就能跑对
    op.Init(x, ..., tiling_data.totalLength, tiling_data.tileLength);
    op.Process();
}

三、 进阶挑战:非连续 Shape 与 Padding

泛化不仅仅是长度变长变短,还包括处理非连续内存特殊对齐

3.1 应对 stride (非连续内存)

在 PyTorch 中,tensor.transpose() 产生的 Tensor 在内存中是不连续的。

  • 初级泛化:要求用户必须 .contiguous() 再传进来。

  • 高级泛化:TilingData 中传递 stride 参数,Kernel 中使用支持 stride 的 DataCopy 接口(或者自定义搬运逻辑),直接处理非连续内存。

3.2 极端 Shape 处理

如果运行时 Shape 极小(比如 total_length = 1),按照常规逻辑 BlockDim 可能会算出 0,或者 TileLength 对齐后出错。 避坑:在 Host Tiling 代码中必须有保底逻辑

if (total_elements < 32) {
    // 极小 Shape 特殊处理:单核,一次搬完,强制对齐
    context->SetBlockDim(1);
    // ...
}

四、 总结

二进制泛化是 Ascend C 算子迈向成熟产品的标志。

  1. Host 侧:根据 Context 动态计算一切。

  2. Device 侧:信任 TilingData,做无情的执行机器。

  3. 架构:Host 负责“智商”(处理复杂 Shape 逻辑),Device 负责“肌肉”(暴力计算)。

掌握了这一期,你的算子就不再是一个只能跑 Demo 的玩具,而是能够集成到大型神经网络中,扛起千变万化输入数据的通用算子

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐