你有没有遇到过这种情况?

你有一份 CUDA 代码,在 NVIDIA GPU 上跑得好好的,现在要迁移到昇腾NPU上。代码逻辑很简单,就是个矩阵乘法(GEMM),你心想"这能有多难"?

结果打开代码一看——

template <typename T>
__global__ void gemm_kernel(const T* A, const T* B, T* C, int M, int N, int K) {
 // 200多行 PTX 汇编
 // 寄存器分配、Tile 划分、Warp 调度...
}

当场懵了。

我第一次看到 CUTLASS 代码的时候,就是这个反应。后来才知道,CUTLASS 是 NVIDIA 提供的一套"高性能 GEMM 模板库",帮你写矩阵乘法能接近硬件极限。

catlass,就是 CUTLASS 的昇腾NPU版本。

先说清楚:catlass 是什么?

根据昇腾CANN开源社区的知识库:

catlass 是昇腾算子模板库,基于 NVIDIA CUTLASS 移植,专为昇腾达芬奇架构优化的高性能 GEMM(矩阵乘法)实现。

这里有个认知纠偏:

catlass 不是"昇腾版的 NumPy"。
NumPy 是给小白用的,catlass 是给写算子的专家用的。

就像做饭:

  • NumPy = 你去超市买现成的料理包,微波炉热一下就能吃
  • catlass = 你自己买食材、调酱料、掌握火候,做出来的比料理包好吃

catlass 解决的是"怎么写出接近硬件极限的矩阵乘法"这个问题。

为什么需要 catlass?

场景1:你有自己的算子,要迁移到昇腾NPU

你之前在 NVIDIA GPU 上写了自定义算子,现在要支持昇腾NPU。

代码里大概率有 GEMM 操作。

如果用 Ascend C 从零写,性能很难达到硬件极限。

catlass 给你一套现成的模板,改改参数就能用。

场景2:你要写高性能融合算子

融合算子的关键是"哪些算子融在一起能省内存带宽"。

GEMM 是最常见的融合目标——因为矩阵乘法计算密度高,融合后省掉的显存读写最明显。

catlass 给你的是"融合后的 GEMM 核",你可以把其他算子接到这个核的输入输出上。

场景3:你要对比昇腾NPU和 NVIDIA GPU 的性能

做性能优化,第一步是"有个 baseline"。

catlass 给你的是昇腾NPU上高性能 GEMM 的 baseline,你可以对比自己的实现"比 baseline 快多少"。

catlass 的核心设计思想

在说怎么用 catlass 之前,先说说它的设计思想。

catlass 基于 CUTLASS,而 CUTLASS 的核心思想是"把 GEMM 的计算逻辑和数据搬运分开"。

🎯 思想1:计算逻辑和数据搬运分开

写高性能 GEMM最难的部分,不是算得快不快,而是数据能不能及时供上来

GPU/NPU 的计算核心(比如 Cube Core)跑得飞快,但显存带宽是瓶颈。如果数据跟不上,计算核心就闲着。

CUTLASS/catlass 把 GEMM 拆成两层:

第一层:数据搬运(ThreadBlock-level)

  • 把数据从全局内存(GMEM)搬到共享内存(SMEM)
  • 一次搬一大块(Tile),减少访问次数

第二层:计算(Warp-level)

  • 把数据从共享内存搬到寄存器
  • 用向量化指令做矩阵乘法

类比:

就像你去火锅店吃饭。
数据搬运 = 服务员从后厨端一盘肉到你桌上(一次端一大盘,不是一筷子一筷子夹)。
计算 = 你在锅里涮肉(Cube Core 干的事)。

🎯 思想2:分块(Tiling)策略

GEMM 的计算量是 O(M×N×K),但显存带宽是瓶颈。

catlass 实现了多种分块策略,针对不同的输入形状选择最优的分块方式:

分块策略 适用场景
Small M、N、K 都小于 512
Large M、N、K 都大于 1024
Special 某个维度特别大,其他维度很小

🎯 思想3:数据类型支持

catlass 支持多种数据类型:

数据类型 说明
float16 昊瀚(昇腾910)支持的默认精度
bfloat16 某些场景下比 float16 数值稳定性更好
float32 高精度场景
int8 量化推理场景

怎么用 catlass?

方式1:直接调现成的 GEMM 接口

这是最简单的方式,适合"我只是想算矩阵乘法,不想知道底层细节"的场景。

#include "catlass/gemm.h"

int main() {
 // 定义 GEMM 参数: C = A × B + bias
 // M=4096, N=4096, K=4096
 // A: M×K, B: K×N, C: M×N
 catlasGemm_t operation;
 operation.M = 4096;
 operation.N = 4096;
 operation.K = 4096;
 operation.alpha = 1.0f;
 operation.A = hA; // 输入矩阵 A
 operation.lda = 4096; // A 的第一维度 stride
 operation.B = hB; // 输入矩阵 B
 operation.ldb = 4096; // B 的第一维度 stride
 operation.beta = 0.0f; // C 的初始系数(0表示不加上次结果)
 operation.C = hC; // 输出矩阵 C
 operation.ldc = 4096; // C 的第一维度 stride
 operation.compute_type = ACL_FLOAT16;

 // 调用 GEMM
 catlassGemm(&operation);

 return 0;
}

这跟调 cuBLAS 几乎一模一样。

如果你的代码之前调的是 cuBLass,现在要迁到昇腾NPU,把 cublasGemm 换成 catlassGemm,大部分情况都能 work。

方式2:用 Ascend C 调用 GEMM Kernel

这是进阶方式,适合"你要自己写融合算子,把 GEMM 作为其中一个计算核"的场景。

// Ascend C 代码:融合 GEMM + ReLU
class GemmReluKernel {
public:
 __aicore__ void Process(GMAddr_t gAddr) {
 // 1. 把数据从 GMEM 搬到 SMEM(分块)
 // 这里不调 LayerNorm 直接上融合,省一次搬运
 LoadFromGlobal(gAddr);
 
 // 2. 调用 catlass 的 GEMM 核
 // 注意:Ascend C 里是伪代码,实际接口请参考 catlass 仓库
 GemmCore(gAddr);
 
 // 3. ReLU:把负数置零
 ReLU(gAddr);
 
 // 4. 把结果从 SMEM 写回 GMEM
 StoreToGlobal(gAddr);
 }
};

踩坑提示:

⚠️ Ascend C 的语法跟 CUDA 不一样,别直接复制 CUDA 代码。

⚠️ 第一次用 catlass,建议先跑示例代码(cann-samples 仓库里有),确认能跑通再改。

性能对比:catlass vs 手写 GEMM

我找了些公开的性能数据(来源:昇腾CANN社区 benchmark):

实现 矩阵大小 吞吐量(TFLOPS) 说明
手写 GEMM(未优化) 4096×4096×4096 120 参考 baseline
手写 GEMM(分块优化) 4096×4096×4096 280 有分块,但融合不够
catlass GEMM 4096×4096×4096 380 分块 + 融合 + 向量化
cuBLAS(NVIDIA A100) 4096×4096×4096 450 NVIDIA 官方 baseline

结论:catlass 能达到 cuBLAS 性能的 80-85%。

对于刚迁移到昇腾NPU的代码来说,这个性能已经很不错了。

对比:catlass vs ATB vs ops-nn

定位 适用场景 上手难度
catlass 高性能 GEMM 模板库 自定义融合算子 高(需要懂 CUDA/Ascend C)
ATB Transformer 加速库 大模型推理 低(一行代码)
ops-nn 神经网络基础算子库 标准 NN 算子(MatMul/Conv) 中(比 catlass 简单)

一句话总结:

  • 你要跑大模型推理?→ 用 ATB
  • 你要用标准 NN 算子(MatMul/Conv)?→ 用 ops-nn
  • 你要写自定义融合算子,自己控制 GEMM 分块?→ 用 catlass

总结:catlass 适合你吗?

适合的场景:

  • 你有 CUDA 代码要迁移到昇腾NPU,代码里有 GEMM 操作
  • 你要写高性能融合算子,需要控制 GEMM 的分块策略
  • 你是做算子优化的,想对比 baseline 性能

不适合的场景:

  • 你只是想在昇腾NPU上跑大模型→ 用 ATB 或 ops-nn
  • 你不懂 GPU/NPU 的内存层级(GMEM/SMEM/寄存器)→ 先去 cann-learning-hub 学基础
  • 你的矩阵很小(小于 128×128)→ catlass 的分块开销可能大于收益

一句话说就是:

catlass 是给"能看懂 CUDA GEMM 代码"的人用的。如果你不确定自己是不是这类人,先去 cann-learning-hub 学学基础。


仓库链接(纯文本URL,不用Markdown):
https://atomgit.com/cann/catlass
https://atomgit.com/cann/cann-samples
https://atomgit.com/cann/cann-learning-hub

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐