本文基于昇腾CANN和昇腾NPU,围绕 TorchAir 与 PyTorch 适配 技术展开。

torch_npu 让 PyTorch 代码能在 NPU 上跑——但每个算子还是要逐一遍历 Python 解释器。TorchAir 走得更远:对接 PyTorch 2.0 的 torch.compile 接口,把 PyTorch 的动态图一次性编译成 CANN 的静态执行计划,绕过了 ONNX 导出和 ATC 的算子选择环节。

PyTorch 的动态图意味着每次前向都可能走不同的算子分支。torch.compile 的做法是"录制"——跑一次前向,记录实际走过的算子序列和 Tensor 形状,生成 FX Graph。TorchAir 把这个 FX Graph 转译成 GE 能消化的图表示,然后由 GE 跑完剩余的全套图优化流水线。


TorchAir 的编译流程

FX Graph → GE 的翻译分三步:

第一步:图标准化。 把 PyTorch 高级算子展开成基础算子。nn.MultiheadAttention 拆成 Reshape、Transpose、MatMul、Softmax——这是 GE 融合模式匹配的前提。高级算子的抽象不给 GE 看,GE 只看基础算子才能做融合。

第二步:设备映射。 每个 PyTorch 算子映射到 Ascend Kernel。有原生映射的(MatMul、Softmax、LayerNorm)直接走;没有的走 CPU Fallback(Tensor 搬回 CPU 执行再搬回来)或等价替换(einsumbmm + permute)。

第三步:编译提交。 翻译好的 GE 图走完完整的图优化流水线——常量折叠、算子替换、模式融合、Buffer 分配——生成 .om。这个 .om 跟 ATC 编译出来的没有本质区别,Runtime 加载和执行逻辑完全一样。

编译路径对比:

ATC 路径:
  PyTorch → ONNX → ATC → GE → OM → Runtime
  5 步,ONNX 环节可能丢失算子语义

TorchAir 路径:
  PyTorch → torch.compile → FX Graph → TorchAir → GE → OM → Runtime
  4 步(少了 ONNX 导出和 ATC 算子选择),算子语义保留更完整

跟 torch_npu 的差异

torch_npu 是后端适配——PyTorch 算子逐一到 CANN Kernel,Graph Compiler 不参与,图的执行流程还是 PyTorch Runtime 管。TorchAir 是图编译——整张图交给 GE 管,融合、Buffer 复用、Stream 调度全部在 GE 侧完成。

性能差异在动态 Shape 下最明显。torch_npu 每次 Shape 变化要重新做算子 dispatch——比如 Batch Size 从 1 变 4,所有 MatMul 的 Kernel 都要重新选。TorchAir 在编译期就标注了动态维度,GE 生成了带 Shape 分支的执行计划——运行时几乎零 dispatch 开销。长序列推理(64K-128K)下 TorchAir 比 torch_npu 快 8-12%。


当前状态和适用场景

TorchAir 还在活跃开发中。算子覆盖率正在追赶 ATC 路径的成熟度——一些 PyTorch 的冷门算子(比如 torch.special.*)只有 CPU Fallback。

适用场景:不想碰 ONNX 导出和 ATC 参数调试的 PyTorch 开发者、动态 Shape 场景(Batch/SeqLen 运行时变化)、需要 torch.compile 的既有工作流。不适合场景:需要极致算子级控制、依赖大量自定义 Kernel 的项目。



动态图编译的边界

TorchAir 的"录制一次,重复执行"策略在静态图场景下完美——ResNet、BERT 这类固定 Shape 的模型几乎零 overhead。但在动态图场景(if 分支、循环长度可变)下,录制只捕获了一次前向的具体路径。如果后续前向走了不同的分支——TorchAir 回退到重新录制 + 重新编译。

Graph Compiler 对这个回退做了缓存:不同 Shape 的编译结果各自存一份。同一个模型跑 3 种 Batch Size——TorchAir 存 3 份编译缓存。首次编译约 2-5 秒(取决于图大小),后续命中缓存直接走——延迟跟 ATC 编译好的 OM 一样。这就是为什么 torch.compile 的第一个 Batch 慢、后续 Batch 快。

参考仓库

TorchAir PyTorch 编译后端

torch_npu PyTorch 适配

GE 图引擎

CANN 学习中心

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐