一、问题描述

1、运行平台

        CPU是华为的鲲鹏920

2、pytorch版本

        torch-1.10.1 cpu版

3、bug场景

        项目中用到了torch.linalg.solve方法来求解方程组,代码在x86_64的平台上运行没有出现内存泄漏,在华为的鲲鹏920 CPU上运行出现了内存泄漏。

# 新旧两版接口都有内存泄漏
Q = torch.solve(Z, L)[0]      # torch 1.9以前版本
Q = torch.linalg.solve(L, Z)

二、问题解决

1、升级pytorch版本

        开发环境使用了torch 2.4.1,在该版本下这个方法没有内存泄漏,具体哪个版本解决了这个bug没有详细测试。

2、替换该方法

        使用scipy库中的scipy.linalg.solve方法替代了torch.linalg.solve,没有内存泄漏,但方法运行效率稍低,包括将tensor转ndarray和将ndarray转tensor在内,调用耗时是torch.linalg.solve方法的三倍左右,具体使用数据测试该方法,1万次调用,torch耗时860ms,scipy耗时2651ms,不过每次调用耗时都在1ms以内,可以接受。

import scipy.linalg as la

# tensor包含了batch size维度,所以使用了squeeze和unsqeeze
Q = la.solve(L.squeeze(0).numpy(), Z.squeeze(0).numpy())
Q = torch.tensor(Q).unsqueeze(0)

至此问题解决,欢迎留言讨论。

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐