昇思训练营day10学习心得-ResNet50图像分类
定义visualize_model函数,使用训练好的模型对数据集进行预测,并将预测结果可视化。2.
学习目标
给定一张图像(猫、狗、飞机、汽车等等),判断图像所属的类别。
数据集准备与加载
使用download接口下载并解压。

构建网络
构建残差网络结构,然后通过堆叠残差网络来构建ResNet50网络。
1.构建残差网络结构
定义残差网络的基本单元(Building Block或Bottleneck),包括主分支的卷积操作、shortcuts的连接方式以及,ReLU激活函数。
说明:
残差网络结构主要由两种,一种是Building Block,适用于较浅的ResNet网络;另一种是Bottleneck,适用于层数较深的ResNet网络。
残差网络由两个分支构成:一个主分支,一个shortcuts。主分支通过堆叠一系列的卷积操作得到,shotcuts从输入直接到输出。
2.构建ResNet50网络
使用定义好的残差网络单元,堆叠成完整的ResNet50网络。设置输入图像的尺寸、各层的卷积核大小、步长等参数,确保网络结构正确。
模型训练与评估
1.调用resnet50构造ResNet50模型,并设置pretrained参数为True,将会自动下载ResNet50预训练模型,并加载预训练模型中的参数到网络中。
2.定义优化器和损失函数,逐个epoch打印训练的损失值和评估精度,并保存评估精度最高的ckpt文件。
注意事项:由于CIFAR-10数据集只有10个类别,而预训练模型的全连接层输出大小为1000,在使用该数据集进行训练时,需要将加载好预训练权重的模型全连接层输出大小重置为10。
可视化模型预测
1.定义可视化函数:定义visualize_model函数,使用训练好的模型对数据集进行预测,并将预测结果可视化。
2.预测与可视化:若预测字体颜色为蓝色表示为预测正确,预测字体颜色为红色则表示预测错误。
鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。
更多推荐

所有评论(0)