在这里插入图片描述

16.1 什么是模板参数算子

16.1.1 模板参数的概念

在之前的算子实现中,数据类型、Tiling策略等都是硬编码的。模板参数算子允许在编译时或运行时通过模板参数来配置算子的行为,使同一个算子可以支持多种数据类型、格式、Tiling策略等。

传统算子

  • 数据类型固定(如float16)
  • Tiling策略固定(如TILE_NUM=8)
  • 需要为每种组合写不同的算子

模板参数算子

  • 数据类型可配置(支持float16、float32等)
  • Tiling策略可配置(支持不同的TILE_NUM)
  • 一个算子可以处理多种组合

16.1.2 模板参数的优势

代码复用:一个算子实现可以处理多种场景,减少代码重复。

灵活配置:可以根据输入动态选择最优的配置。

性能优化:可以为不同场景选择最优的Tiling策略。

易于维护:修改逻辑只需要改一处,不需要改多个算子。

16.1.3 模板参数的类型

模板参数可以包括:

  • 数据类型(DTYPE):float16、float32等
  • 数据格式(FORMAT):ND、FRACTAL_NZ等
  • 整数参数(UINT):TILE_NUM、BLOCK_DIM等
  • 布尔参数(BOOL):IS_SPLIT、是否需要workspace等

16.2 模板参数定义

16.2.1 Tiling Key定义

在tiling_key_add_custom.h中定义模板参数:

#include "ascendc/host_api/tiling/template_argument.h"

// 定义数据类型枚举值
#define ADD_TPL_FP16 10
#define ADD_TPL_FP32 20

// 定义模板参数
ASCENDC_TPL_ARGS_DECL(AddTemplateCustom,  // 算子唯一标识
    // 数据类型参数:支持FP16和FP32
    ASCENDC_TPL_DTYPE_DECL(D_T_X, ADD_TPL_FP16, ADD_TPL_FP32),
    ASCENDC_TPL_DTYPE_DECL(D_T_Y, ADD_TPL_FP16, ADD_TPL_FP32),
    ASCENDC_TPL_DTYPE_DECL(D_T_Z, ADD_TPL_FP16, ADD_TPL_FP32),
    
    // 整数参数:TILE_NUM,使用8比特位宽,混合模式
    // 范围:{0, 1, 2, 3, 4, 5},穷举:{10, 12, 13, 9, 8}
    ASCENDC_TPL_UINT_DECL(TILE_NUM, ASCENDC_TPL_8_BW, 
                         ASCENDC_TPL_UI_MIX, 
                         2, 0, 2, 3, 5,  // 2组范围
                         10, 12, 13, 9, 8),  // 穷举值
    
    // 布尔参数:IS_SPLIT,1比特位宽
    ASCENDC_TPL_BOOL_DECL(IS_SPLIT, 0, 1),
);

ASCENDC_TPL_ARGS_DECL:定义算子的所有模板参数。

ASCENDC_TPL_DTYPE_DECL:定义数据类型参数,支持多个枚举值。

ASCENDC_TPL_UINT_DECL:定义整数参数,支持范围模式和穷举模式。

ASCENDC_TPL_BOOL_DECL:定义布尔参数,支持0和1。

16.2.2 模板参数组合

定义模板参数的有效组合:

ASCENDC_TPL_SEL(
    // 组合1:FP16类型
    ASCENDC_TPL_ARGS_SEL(
        ASCENDC_TPL_DTYPE_SEL(D_T_X, ADD_TPL_FP16),
        ASCENDC_TPL_DTYPE_SEL(D_T_Y, ADD_TPL_FP16),
        ASCENDC_TPL_DTYPE_SEL(D_T_Z, ADD_TPL_FP16),
        ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8),
        ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1),
        ASCENDC_TPL_TILING_STRUCT_SEL(optiling::TilingDataFp16)  // 对应的Tiling结构体
    ),
    
    // 组合2:FP32类型
    ASCENDC_TPL_ARGS_SEL(
        ASCENDC_TPL_DTYPE_SEL(D_T_X, ADD_TPL_FP32),
        ASCENDC_TPL_DTYPE_SEL(D_T_Y, ADD_TPL_FP32),
        ASCENDC_TPL_DTYPE_SEL(D_T_Z, ADD_TPL_FP32),
        ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8),
        ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1),
        ASCENDC_TPL_TILING_STRUCT_SEL(optiling::TilingDataFp)  // 对应的Tiling结构体
    )
);

ASCENDC_TPL_SEL:定义所有有效的模板参数组合。

ASCENDC_TPL_ARGS_SEL:定义一个具体的参数组合。

ASCENDC_TPL_TILING_STRUCT_SEL:指定该组合对应的Tiling结构体。

16.2.3 Tiling结构体定义

为不同的模板参数组合定义不同的Tiling结构体:

namespace optiling {
// 默认Tiling结构体
class TilingData {
public:
    uint32_t totalLength;
};

// FP32类型的Tiling结构体
class TilingDataFp {
public:
    uint32_t totalLength;
};

// FP16类型的Tiling结构体
class TilingDataFp16 {
public:
    uint32_t totalLength;
};
} // namespace optiling

不同的模板参数组合可以使用不同的Tiling结构体,这样可以支持不同的Tiling参数。


16.3 Host端实现

16.3.1 Tiling函数

在Tiling函数中,根据输入动态生成TilingKey:

namespace optiling {
const uint32_t BLOCK_DIM = 8;
const uint32_t DEFAULT_TILE_NUM = 8;
constexpr int MIN_LENGTH_FOR_SPLIT = 2048;

static ge::graphStatus TilingFunc(gert::TilingContext *context)
{
    // 1. 获取输入信息
    uint32_t totalLength = context->GetInputShape(0)->GetOriginShape().GetShapeSize();
    ge::DataType dataTypeX = context->GetInputDesc(0)->GetDataType();
    ge::DataType dataTypeY = context->GetInputDesc(1)->GetDataType();
    ge::DataType dataTypeZ = context->GetOutputDesc(0)->GetDataType();
    
    // 2. 根据输入数据类型确定模板参数
    uint32_t D_T_X = ADD_TPL_FP32, D_T_Y = ADD_TPL_FP32, D_T_Z = ADD_TPL_FP32;
    uint32_t TILE_NUM = 1, IS_SPLIT = 0;
    
    if (dataTypeX == ge::DataType::DT_FLOAT) {
        D_T_X = ADD_TPL_FP32;
    } else if (dataTypeX == ge::DataType::DT_FLOAT16) {
        D_T_X = ADD_TPL_FP16;
    }
    
    if (dataTypeY == ge::DataType::DT_FLOAT) {
        D_T_Y = ADD_TPL_FP32;
    } else if (dataTypeY == ge::DataType::DT_FLOAT16) {
        D_T_Y = ADD_TPL_FP16;
    }
    
    if (dataTypeZ == ge::DataType::DT_FLOAT) {
        D_T_Z = ADD_TPL_FP32;
    } else if (dataTypeZ == ge::DataType::DT_FLOAT16) {
        D_T_Z = ADD_TPL_FP16;
    }
    
    // 3. 根据数据长度确定Tiling策略
    if (totalLength < MIN_LENGTH_FOR_SPLIT) {
        IS_SPLIT = 0;
        TILE_NUM = 1;
    } else {
        IS_SPLIT = 1;
        TILE_NUM = DEFAULT_TILE_NUM;
    }
    
    // 4. 根据模板参数选择Tiling结构体并设置值
    if (D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32) {
        TilingDataFp *tiling = context->GetTilingData<TilingDataFp>();
        tiling->totalLength = totalLength;
    } else if (D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16) {
        TilingDataFp16 *tiling = context->GetTilingData<TilingDataFp16>();
        tiling->totalLength = totalLength;
    }
    
    // 5. 生成TilingKey并设置
    context->SetBlockDim(BLOCK_DIM);
    const uint64_t tilingKey = GET_TPL_TILING_KEY(D_T_X, D_T_Y, D_T_Z, TILE_NUM, IS_SPLIT);
    context->SetTilingKey(tilingKey);
    
    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
    currentWorkspace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
} // namespace optiling

关键步骤

  1. 获取输入信息(形状、数据类型)
  2. 根据输入确定模板参数值
  3. 根据模板参数选择对应的Tiling结构体
  4. 生成TilingKey并设置

GET_TPL_TILING_KEY:根据模板参数生成TilingKey,用于在编译时选择对应的Kernel代码路径。

16.3.2 算子注册

算子注册时支持多种数据类型:

namespace ops {
class AddCustom : public OpDef {
public:
    explicit AddCustom(const char *name) : OpDef(name)
    {
        // 支持多种数据类型
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT})  // 支持FP16和FP32
            .Format({ge::FORMAT_ND, ge::FORMAT_ND});
        
        this->Input("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND});
        
        this->Output("z")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND});
        
        this->SetInferShape(ge::InferShape)
            .SetInferDataType(ge::InferDataType);
        this->AICore()
            .SetTiling(optiling::TilingFunc)
            .AddConfig("ascend910")
            .AddConfig("ascend310p")
            .AddConfig("ascend310b")
            .AddConfig("ascend910b");
    }
};

OP_ADD(AddCustom);
} // namespace ops

16.4 Kernel端实现

16.4.1 模板Kernel函数

Kernel函数使用模板参数:

template<int D_T_X, int D_T_Y, int D_T_Z, int TILE_NUM, int IS_SPLIT>
__global__ __aicore__ void add_custom(
    GM_ADDR x, GM_ADDR y, GM_ADDR z, 
    GM_ADDR workspace, GM_ADDR tiling)
{
    // 1. 注册Tiling结构体
    REGISTER_TILING_DEFAULT(optiling::TilingData);
    
    // 2. 根据模板参数注册对应的Tiling结构体
    REGISTER_TILING_FOR_TILINGKEY(
        "D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32", 
        optiling::TilingDataFp);
    
    REGISTER_TILING_FOR_TILINGKEY(
        "D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16", 
        optiling::TilingDataFp16);
    
    // 3. 根据模板参数选择代码路径
    if (D_T_X == ADD_TPL_FP32 && D_T_Y == ADD_TPL_FP32 && D_T_Z == ADD_TPL_FP32) {
        GET_TILING_DATA_WITH_STRUCT(optiling::TilingDataFp, tiling_data, tiling);
        KernelAdd<float, float, float> op;
        op.Init(x, y, z, tiling_data.totalLength, TILE_NUM);
        op.Process1();  // FP32使用Process1
    } else if (D_T_X == ADD_TPL_FP16 && D_T_Y == ADD_TPL_FP16 && D_T_Z == ADD_TPL_FP16) {
        GET_TILING_DATA_WITH_STRUCT(optiling::TilingDataFp16, tiling_data, tiling);
        KernelAdd<half, half, half> op;
        if (IS_SPLIT == 0) {
            op.Init(x, y, z, tiling_data.totalLength, TILE_NUM);
            op.Process1();  // 不切分
        } else if (IS_SPLIT == 1) {
            op.Init(x, y, z, tiling_data.totalLength, TILE_NUM);
            op.Process2();  // 切分
        }
    }
}

关键点

  • Kernel函数是模板函数,模板参数对应TilingKey中的参数
  • 根据模板参数选择不同的代码路径
  • 使用GET_TILING_DATA_WITH_STRUCT获取对应的Tiling数据

16.4.2 模板Kernel类

Kernel类也是模板类,支持不同的数据类型:

template<class dtypeX, class dtypeY, class dtypeZ>
class KernelAdd {
public:
    __aicore__ inline KernelAdd() {}
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, 
                                uint32_t totalLength, uint32_t tileNum)
    {
        this->blockLength = totalLength / AscendC::GetBlockNum();
        this->tileNum = tileNum;
        uint32_t tile_type = 1;
        
        // 根据TILE_NUM选择不同的tileLength计算方式
        if (tileNum == tile_type) {
            this->tileLength = totalLength;
        } else {
            this->tileLength = this->blockLength / tileNum / BUFFER_NUM;
        }
        
        // 使用模板类型
        xGm.SetGlobalBuffer((__gm__ dtypeX *)x + 
                           this->blockLength * AscendC::GetBlockIdx(), 
                           this->blockLength);
        yGm.SetGlobalBuffer((__gm__ dtypeY *)y + 
                           this->blockLength * AscendC::GetBlockIdx(), 
                           this->blockLength);
        zGm.SetGlobalBuffer((__gm__ dtypeZ *)z + 
                           this->blockLength * AscendC::GetBlockIdx(), 
                           this->blockLength);
        
        pipe.InitBuffer(inQueueX, BUFFER_NUM, 
                       this->tileLength * sizeof(dtypeX));
        pipe.InitBuffer(inQueueY, BUFFER_NUM, 
                       this->tileLength * sizeof(dtypeY));
        pipe.InitBuffer(outQueueZ, BUFFER_NUM, 
                       this->tileLength * sizeof(dtypeZ));
    }
    
    // Process1:不切分,只处理一次
    __aicore__ inline void Process1()
    {
        CopyIn(0);
        Compute(0);
        CopyOut(0);
    }
    
    // Process2:切分,循环处理
    __aicore__ inline void Process2()
    {
        int32_t loopCount = this->tileNum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
    }
    
private:
    // CopyIn、Compute、CopyOut使用模板类型
    __aicore__ inline void CopyIn(int32_t progress)
    {
        AscendC::LocalTensor<dtypeX> xLocal = inQueueX.AllocTensor<dtypeX>();
        AscendC::LocalTensor<dtypeY> yLocal = inQueueY.AllocTensor<dtypeY>();
        AscendC::DataCopy(xLocal, xGm[progress * this->tileLength], this->tileLength);
        AscendC::DataCopy(yLocal, yGm[progress * this->tileLength], this->tileLength);
        inQueueX.EnQue(xLocal);
        inQueueY.EnQue(yLocal);
    }
    
    __aicore__ inline void Compute(int32_t progress)
    {
        AscendC::LocalTensor<dtypeX> xLocal = inQueueX.DeQue<dtypeX>();
        AscendC::LocalTensor<dtypeY> yLocal = inQueueY.DeQue<dtypeY>();
        AscendC::LocalTensor<dtypeZ> zLocal = outQueueZ.AllocTensor<dtypeZ>();
        AscendC::Add(zLocal, xLocal, yLocal, this->tileLength);
        outQueueZ.EnQue<dtypeZ>(zLocal);
        inQueueX.FreeTensor(xLocal);
        inQueueY.FreeTensor(yLocal);
    }
    
    __aicore__ inline void CopyOut(int32_t progress)
    {
        AscendC::LocalTensor<dtypeZ> zLocal = outQueueZ.DeQue<dtypeZ>();
        AscendC::DataCopy(zGm[progress * this->tileLength], zLocal, this->tileLength);
        outQueueZ.FreeTensor(zLocal);
    }
    
private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX, inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;
    AscendC::GlobalTensor<dtypeX> xGm;
    AscendC::GlobalTensor<dtypeY> yGm;
    AscendC::GlobalTensor<dtypeZ> zGm;
    uint32_t blockLength;
    uint32_t tileNum;
    uint32_t tileLength;
};

16.5 TilingKey机制

16.5.1 TilingKey的作用

TilingKey是一个64位的整数,编码了所有模板参数的值。框架使用TilingKey来:

  • 在编译时选择对应的Kernel代码路径
  • 在运行时匹配正确的Kernel实例
  • 优化代码生成,只编译需要的代码路径

16.5.2 TilingKey的生成

在Host端的Tiling函数中生成TilingKey:

const uint64_t tilingKey = GET_TPL_TILING_KEY(D_T_X, D_T_Y, D_T_Z, TILE_NUM, IS_SPLIT);
context->SetTilingKey(tilingKey);

GET_TPL_TILING_KEY宏根据模板参数的值生成TilingKey,每个参数占用一定的比特位。

16.5.3 TilingKey的匹配

在Kernel端,框架根据TilingKey选择对应的Kernel实例。模板参数的值在编译时确定,不同的TilingKey会编译出不同的Kernel代码。


16.6 模板参数的优势

16.6.1 代码复用

一个Kernel实现可以处理多种数据类型:

// 传统方式:需要为每种类型写不同的Kernel
void add_custom_fp16(...) { KernelAdd<half, half, half> op; ... }
void add_custom_fp32(...) { KernelAdd<float, float, float> op; ... }

// 模板方式:一个Kernel处理所有类型
template<int D_T_X, int D_T_Y, int D_T_Z, ...>
void add_custom(...) { 
    if (D_T_X == ADD_TPL_FP16) { KernelAdd<half, half, half> op; ... }
    else if (D_T_X == ADD_TPL_FP32) { KernelAdd<float, float, float> op; ... }
}

16.6.2 灵活配置

可以根据输入动态选择最优配置:

// 根据数据长度选择是否切分
if (totalLength < MIN_LENGTH_FOR_SPLIT) {
    IS_SPLIT = 0;  // 小数据不切分
    TILE_NUM = 1;
} else {
    IS_SPLIT = 1;  // 大数据切分
    TILE_NUM = DEFAULT_TILE_NUM;
}

16.6.3 性能优化

可以为不同场景选择最优策略:

// FP32使用Process1(简单路径)
if (D_T_X == ADD_TPL_FP32) {
    op.Process1();
}
// FP16根据IS_SPLIT选择
else if (D_T_X == ADD_TPL_FP16) {
    if (IS_SPLIT == 0) {
        op.Process1();  // 不切分
    } else {
        op.Process2();  // 切分
    }
}

16.7 与普通算子的对比

16.7.1 代码复杂度

普通算子

  • 代码简单,但需要为每种组合写不同的算子
  • 代码重复多,维护困难

模板参数算子

  • 代码稍复杂,但一个算子处理多种组合
  • 代码复用高,维护方便

16.7.2 灵活性

普通算子

  • 灵活性低,配置固定
  • 需要修改代码才能支持新配置

模板参数算子

  • 灵活性高,配置可动态选择
  • 只需修改模板参数定义即可支持新配置

16.7.3 性能

普通算子

  • 编译时优化简单
  • 但代码体积大(多个算子)

模板参数算子

  • 编译时根据TilingKey优化
  • 只编译需要的代码路径,代码体积小

16.8 关键注意事项

16.8.1 Tiling结构体匹配

不同的模板参数组合必须使用对应的Tiling结构体:

// Host端
if (D_T_X == ADD_TPL_FP32) {
    TilingDataFp *tiling = context->GetTilingData<TilingDataFp>();
    // ...
}

// Kernel端
REGISTER_TILING_FOR_TILINGKEY(
    "D_T_X == ADD_TPL_FP32 && ...", 
    optiling::TilingDataFp);
GET_TILING_DATA_WITH_STRUCT(optiling::TilingDataFp, tiling_data, tiling);

如果Tiling结构体不匹配,会导致OOM(Out of Memory)问题。

16.8.2 TilingKey编码

TilingKey的编码是固定的,一旦定义后不能随意修改:

  • 修改参数定义会影响编码位置
  • 可能导致TilingKey不兼容
  • 需要仔细规划参数的位置和位宽

16.8.3 模板参数组合

必须确保模板参数组合在ASCENDC_TPL_SEL中定义,否则编译会失败。

16.8.4 代码路径选择

在Kernel中,必须根据模板参数正确选择代码路径,确保逻辑正确。


16.9 扩展:添加新的模板参数

16.9.1 添加新的数据类型

如果要支持新的数据类型(如int8),需要:

  1. 在tiling_key_add_custom.h中定义枚举值:
#define ADD_TPL_INT8 30
  1. 在模板参数定义中添加:
ASCENDC_TPL_DTYPE_DECL(D_T_X, ADD_TPL_FP16, ADD_TPL_FP32, ADD_TPL_INT8),
  1. 在模板参数组合中添加:
ASCENDC_TPL_ARGS_SEL(
    ASCENDC_TPL_DTYPE_SEL(D_T_X, ADD_TPL_INT8),
    // ...
),
  1. 在Kernel中添加代码路径:
else if (D_T_X == ADD_TPL_INT8) {
    KernelAdd<int8_t, int8_t, int8_t> op;
    // ...
}

16.9.2 添加新的整数参数

如果要添加新的整数参数(如BLOCK_DIM),需要:

  1. 在模板参数定义中添加:
ASCENDC_TPL_UINT_DECL(BLOCK_DIM, ASCENDC_TPL_4_BW, ASCENDC_TPL_UI_LIST, 4, 8, 16),
  1. 在Tiling函数中设置值:
uint32_t BLOCK_DIM = 8;  // 根据情况确定
const uint64_t tilingKey = GET_TPL_TILING_KEY(..., BLOCK_DIM);
  1. 在Kernel中使用:
template<..., int BLOCK_DIM>
void add_custom(...) {
    // 使用BLOCK_DIM
}

16.10 适用场景

16.10.1 适合使用模板参数的情况

需要支持多种数据类型:算子需要支持float16、float32等多种类型。

需要灵活的Tiling策略:根据数据大小、形状等选择不同的Tiling策略。

需要代码复用:多个相似的算子可以合并为一个模板参数算子。

需要性能优化:为不同场景选择最优的代码路径。

16.10.2 不适合使用模板参数的情况

简单的固定配置算子:如果算子配置固定,不需要模板参数。

模板参数组合过多:如果组合数量太多,会导致编译时间过长。

逻辑差异很大:如果不同配置的逻辑差异很大,分开实现可能更清晰。

2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。

报名链接:https://www.hiascend.com/developer/activities/cann20252

社区地址:https://www.hiascend.com/developer

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐