昇腾NPU上的FlashAttention藏在哪?ops-transformer仓库全景图
刚接触昇腾CANN生态的时候,光是找FlashAttention算子在哪就花了不少时间。官方文档按功能模块划分,仓库按层级划分,两套逻辑对不上号——文档里写的是"大模型算子优化",仓库里是一个叫ops-transformer的目录。这种对应关系不搞清楚,改代码都不知道去哪改。ops-transformer是昇腾CANN大模型算子的主阵地。FlashAttention、RMSNorm、RoPE、Sw
刚接触昇腾CANN生态的时候,光是找FlashAttention算子在哪就花了不少时间。官方文档按功能模块划分,仓库按层级划分,两套逻辑对不上号——文档里写的是"大模型算子优化",仓库里是一个叫ops-transformer的目录。这种对应关系不搞清楚,改代码都不知道去哪改。
ops-transformer是昇腾CANN大模型算子的主阵地。FlashAttention、RMSNorm、RoPE、SwiGLU——大模型推理训练用到的核心算子,全在这个仓库里。
仓库在CANN生态里的位置
昇腾CANN的软件栈从下到上分四层:
硬件层:Ascend 910 NPU
↑
驱动与Runtime层:CANN基础组件(固件、CCE编译器、调度器)
↑
算子层:ops-transformer / catlass / ops-ascendc
↑
框架适配层:torch_npu / msModelZoo / ge
ops-transformer在算子层,偏"大模型"方向。通用算子(卷积、池化、BatchNorm等)在ops-ascendc,不在ops-transformer里。这两个仓库的边界要记清楚,不然在ops-transformer里搜"conv2d"永远搜不到。
ops-transformer跟catlass是上下游关系。catlass提供分块矩阵乘、在线softmax等基础操作的模板,ops-transformer的FlashAttention依赖catlass的模板实现具体算法。开发新算子的时候:catlass提供积木,ops-transformer负责搭房子。
目录结构:FlashAttention藏在哪
git clone https://atomgit.com/cann/ops-transformer.git
cd ops-transformer
find . -type f -name "*.cc" -o -name "*.h" | head -30
顶层结构:
ops-transformer/
├── opkernel/ # 算子内核实现(改FlashAttention核心逻辑的地方)
│ ├── flash_attention/
│ │ ├── flash_attention_score.cc # 前向:分块+在线softmax
│ │ ├── flash_attention_score_grad.cc # 反向:重计算
│ │ └── flash_attention_score_tiling.cc # 分块策略:UB容量、块大小
│ ├── rms_norm/
│ │ └── rms_norm.cc
│ └── rope/
│ └── rope.cc
├── opplugin/ # 算子注册:GE怎么找到FlashAttention
│ └── flash_attention/
│ └── flash_attention_op.cc
├── inc/ # 公共头文件
│ └── ops_kernel.h
├── scripts/ # 编译脚本
└── cmake/
└── FindCANN.cmake
三个核心文件各有分工:tiling.cc算分块参数,score.cc写计算逻辑,score_grad.cc写反向传播。
tiling.cc:分块大小不是随便定的
FlashAttention的核心思想是分块计算——把Q、K、V切成小块,每次只拿一小块算,算完就扔中间结果。分块大小怎么定?这是tiling.cc要做的事。
昇腾NPU有个硬约束:Unified Buffer约256KB。UB是达芬奇架构上的高速存储,计算单元直接访问,数据必须先从全局内存(HBM)搬到UB才能用。一次计算需要同时放得下Q块、K块、V块、输出块、softmax中间结果,分块太大直接溢出。
// ops-transformer/opkernel/flash_attention/flash_attention_score_tiling.cc
// 简化分块策略(伪代码)
struct TilingParam {
uint32_t block_m; // Q块在seq维的大小
uint32_t block_n; // K/V块在seq维的大小
uint32_t block_k; // head_dim
};
TilingParam CalcTiling(uint32_t seq_len, uint32_t head_dim) {
// 昇腾910的UB约256KB
const uint32_t kUBSize = 256 * 1024; // 字节
const uint32_t dtype_bytes = 2; // FP16 = 2字节
// 单块需要的空间估算:
// Q块(block_m × head_dim) + K块(block_n × head_dim) +
// V块(block_n × head_dim) + 输出块(block_m × head_dim) +
// softmax中间结果(block_m × block_n)
uint32_t elements_per_block = kUBSize / (5 * dtype_bytes);
// 反推seq维的块大小,向下取到16的倍数(对齐约束)
uint32_t block_seq = elements_per_block / head_dim;
block_seq = (block_seq / 16) * 16;
block_seq = std::min(block_seq, seq_len);
return {block_seq, block_seq, head_dim};
}
分块策略不是越大越好。块大了UB装不下,块小了循环次数多、算子调度开销大。tiling.cc的职责就是在硬件约束下找最优的块大小。昇腾910的UB容量、128字节对齐要求、Cube计算单元的tile形状——这些硬件参数全部要反映在分块逻辑里。
score.cc:前向计算的核心
tiling.cc算出了块大小,score.cc负责按这个分块做计算。核心逻辑就是论文里的在线softmax:
// flash_attention_score.cc 核心循环(伪代码)
for (uint32_t qi = 0; qi < num_q_blocks; qi++) {
// 加载一块Q到UB
LoadTensor(cur_q, q_base, qi * block_m);
// 初始化softmax累加器
float row_max = -1e9f;
float row_sum = 0.0f;
float* acc_out = ub_buffer;
for (uint32_t ki = 0; ki < num_kv_blocks; ki++) {
// 加载K/V块到UB(与计算流水化)
LoadTensor(cur_k, k_base, ki * block_n);
LoadTensor(cur_v, v_base, ki * block_n);
// 局部注意力分数:Q × K^T
// 达芬奇Cube单元做矩阵乘,高吞吐
Gemm(cur_q, cur_k.T, local_scores, block_m, block_n, head_dim);
Scale(local_scores, 1.0f / sqrtf(head_dim));
// causal mask:达芬奇上用块级跳过,不算下三角矩阵
// 如果整块都在mask外,直接跳过,省掉整块计算量
if (IsBlockOutsideCausal(qi, ki, block_m, block_n)) {
continue;
}
// 在线softmax更新
float local_max = ReduceMax(local_scores);
float new_max = fmax(row_max, local_max);
// 关键:缩放之前的累加结果
// exp(row_max - new_max)保证了数学等价性
float correction = expf(row_max - new_max);
row_sum *= correction;
Scale(acc_out, correction);
// 加上当前块的贡献
float local_sum = ReduceSum(Exp(local_scores - new_max));
row_sum += local_sum;
Accumulate(acc_out, Exp(local_scores - new_max), cur_v);
row_max = new_max;
}
// 归一化输出
Scale(acc_out, 1.0f / row_sum);
StoreTensor(out_base, qi * block_m, acc_out);
}
昇腾NPU上有几个值得注意的细节:
causal mask用块级跳过。标准的因果mask是算出一个完整的下三角矩阵再逐元素乘。FlashAttention天然适合块级跳过——如果整块K/V都在当前Q的因果范围之外,直接continue跳过这一块,不浪费计算资源。
数据搬运和计算流水化。加载下一块K/V的同时,当前块的矩阵乘和softmax计算同步进行。达芬奇架构的DMA引擎和Cube计算单元独立运作,overlap得好,搬运时间可以被计算时间掩盖。
数值稳定性。FP16下exp(x)在x>88.7时溢出。每次更新最大值后用exp(row_max - new_max)重新缩放之前的累加结果,把所有分数拉到同一个尺度——这行代码是整个在线softmax数值稳定的关键。
score_grad.cc:反向传播怎么重计算
标准attention前向时存了scores和attn两个N×N矩阵,反向时直接用。FlashAttention不存中间结果,反向时重新算一遍——这就是重计算(recomputation)。
// flash_attention_score_grad.cc 反向逻辑(伪代码)
void FlashAttentionGrad(
Tensor grad_out, // 输出梯度
Tensor q, k, v, // 前向输入
Tensor out, # 前向输出(必须存)
Tensor grad_q, grad_k, grad_v // 输出梯度
) {
// 前向时存的两个东西:out和row_max/row_sum
// 没有存scores和attn,反向时重新算
// 第1步:算输出对out的梯度
// dP = grad_out × V^T
for each block:
dP_block = Gemm(grad_out_block, v_block.T);
// softmax反向:dS = P * (dP - sum(dP * P))
dS_block = SoftmaxBackward(dP_block, out_block);
// 累积到grad_q
grad_q_block += Gemm(dS_block, K);
// 第2步:重计算Q×K^T,再反向传播到grad_k和grad_v
for each block:
# 重新算前向的scores
scores_block = Gemm(Q_block, K_block.T) * scale;
# 链式法则反向
grad_k_block += Gemm(Q_block.T, dS_block);
grad_v_block += Gemm(dS_block.T, V_block);
}
重计算的代价是多算一遍前向,但换来了显存从O(N²)降到O(N)。达芬奇架构的算力充沛,显存带宽才是瓶颈——这个trade-off划算。
算子怎么注册到GE
写完计算逻辑,还得让框架能调到它。opplugin/flash_attention_op.cc负责把FlashAttention注册到GE图引擎:
// flash_attention_op.cc 算子注册(伪代码)
IMPLEMT_INFERFUNC(FlashAttentionScore, FlashAttentionScoreInfer) {
// GE需要知道输出的shape,推导逻辑:
// output.shape == Q.shape
auto q_shape = op.GetInputDesc(0).GetShape().GetDims();
op.UpdateOutputDesc("output", TensorDesc(q_shape, FORMAT_ND, DT_FLOAT16));
return GRAPH_SUCCESS;
}
REGISTER_OP("FlashAttentionScore")
.Input("q: float16")
.Input("k: float16")
.Input("v: float16")
.Output("output: float16")
.Attr("head_num: int")
.Attr("input_layout: string")
.Attr("scale: float")
.InferShapeFunction(FlashAttentionScoreInfer);
注册完成后,torch_npu调用npu_flash_attention时,GE根据算子名找到这个注册,把计算委托给opkernel/flash_attention/下的实现。
跟其他仓库的关系
Ascend C(达芬奇架构的C++编程接口)
↑
catlass(算子模板库:GEMM、softmax、reduce)
↑
ops-transformer(算子实现:FlashAttention、RMSNorm等)
↑
ge(图引擎:算子编排、图优化、算子融合)
↑
torch_npu(PyTorch适配:npu_flash_attention等接口)
开发FlashAttention变体时,逻辑改ops-transformer,分块策略改catlass的模板参数,图优化改ge。
编译和验证
改完代码后的编译流程:
cd ops-transformer
mkdir build && cd build
source /usr/local/Ascend/ascend-toolkit/set_env.sh
cmake .. -DCANN_INSTALL_PATH=/usr/local/Ascend/ascend-toolkit \
-DCMAKE_BUILD_TYPE=Release
# 只编译FlashAttention单个算子,比全量编译快
make flash_attention -j8
# 验证编译产物
ls output/opkernel/libflash_attention*.so
替换到torch_npu路径或指定自定义算子库路径后,用数值验证脚本确认改动没有引入回归。
克隆ops-transformer仓库,按opkernel/flash_attention/→opplugin/flash_attention_op.cc的顺序读代码。tiling.cc的分块策略和score.cc的在线softmax是核心,理解了这两块就抓住了FlashAttention昇腾实现的主线。
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐



所有评论(0)