学习目标

给定一张图像(猫、狗、飞机、汽车等等),判断图像所属的类别。

数据集准备与加载

使用download接口下载并解压。

构建网络

构建残差网络结构,然后通过堆叠残差网络来构建ResNet50网络。

1.构建残差网络结构

定义残差网络的基本单元(Building Block或Bottleneck),包括主分支的卷积操作、shortcuts的连接方式以及,ReLU激活函数。

说明:

残差网络结构主要由两种,一种是Building Block,适用于较浅的ResNet网络;另一种是Bottleneck,适用于层数较深的ResNet网络。

残差网络由两个分支构成:一个主分支,一个shortcuts。主分支通过堆叠一系列的卷积操作得到,shotcuts从输入直接到输出。

2.构建ResNet50网络

使用定义好的残差网络单元,堆叠成完整的ResNet50网络。设置输入图像的尺寸、各层的卷积核大小、步长等参数,确保网络结构正确。

模型训练与评估

1.调用resnet50构造ResNet50模型,并设置pretrained参数为True,将会自动下载ResNet50预训练模型,并加载预训练模型中的参数到网络中。

2.定义优化器和损失函数,逐个epoch打印训练的损失值和评估精度,并保存评估精度最高的ckpt文件。

注意事项:由于CIFAR-10数据集只有10个类别,而预训练模型的全连接层输出大小为1000,在使用该数据集进行训练时,需要将加载好预训练权重的模型全连接层输出大小重置为10。

可视化模型预测

1.定义可视化函数:定义visualize_model函数,使用训练好的模型对数据集进行预测,并将预测结果可视化。

2.预测与可视化:若预测字体颜色为蓝色表示为预测正确,预测字体颜色为红色则表示预测错误。

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐