通过MindSpore的API来快速实现一个简单的深度学习模型
MindSpore提供基于Pipeline的,通过和实现高效的数据预处理。在本教程中,我们使用Mnist数据集,自动下载完成后,使用提供的数据变换进行预处理。在下载Mnist数据集后,使用提供的数据变换进行预处理。首先下载:数据下载完成后,获得数据集对象。打印数据集中包含的数据列名,用于dataset的预处理。
MindSpore提供基于Pipeline的数据引擎,通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。在本教程中,我们使用Mnist数据集,自动下载完成后,使用mindspore.dataset提供的数据变换进行预处理。
在下载Mnist数据集后,使用mindspore.dataset提供的数据变换进行预处理。
首先下载:
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
数据下载完成后,获得数据集对象。
train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')
打印数据集中包含的数据列名,用于dataset的预处理。
print(train_dataset.get_col_names())
输出为:['image', 'label']
确认无误后,开始处理图片格式:
def datapipe(dataset, batch_size):
image_transforms = [
vision.Rescale(1.0 / 255.0, 0),
vision.Normalize(mean=(0.1307,), std=(0.3081,)),
vision.HWC2CHW()
]
label_transform = transforms.TypeCast(mindspore.int32)
dataset = dataset.map(image_transforms, 'image')
dataset = dataset.map(label_transform, 'label')
dataset = dataset.batch(batch_size)
return dataset
# Map vision transforms and batch dataset train_dataset = datapipe(train_dataset, 64) test_dataset = datapipe(test_dataset, 64)
这一步主要是使用map对图像数据及标签进行变换处理,将输入的图像缩放为1/255,根据均值0.1307和标准差值0.3081进行归一化处理,然后将处理好的数据集打包为大小为64的batch。
之后使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。

输出准确,确认无误。
开始构建网络。
mindspore.nn类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell类,并重写__init__方法和construct方法。__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。
模型训练
在模型训练中,一个完整的训练过程(step)需要实现以下三步:
-
正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
-
反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
参数优化:将梯度更新到参数上。
-
MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:
-
定义正向计算函数。
-
使用value_and_grad通过函数变换获得梯度计算函数。
-
定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。
代码与输出如下:

模型训练完成后,需要将其参数进行保存。还要加载:

加载后的模型,可直接用于推理:

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐



所有评论(0)