1. 模型介绍

  • 模型简介:CycleGAN用于在没有成对图像的情况下学习图像从一个域转换到另一个域。这项技术在域迁移和图像风格迁移中非常有用。
  • 模型结构:CycleGAN由两个对称的生成对抗网络(GAN)组成,包括生成器和判别器。每个生成器将图像从一个风格转换到另一个,而判别器则区分真实图像和生成图像。

2. 损失函数

  • 循环一致性损失(Cycle Consistency Loss):这是CycleGAN中最重要的损失函数,确保图像在经过风格转换来回后尽可能接近原始图像。

3. 数据集

  • 数据来源:使用的数据集来源于ImageNet,特别使用了苹果和橘子的图片。
  • 数据预处理:包括随机裁剪、水平翻转和归一化,并将预处理后的数据转换为MindRecord格式。

4. 数据集下载和加载

  • 下载:使用download接口下载数据集,并自动解压到当前目录。
  • 加载:使用MindSpore的MindDataset接口读取和解析数据集。

5. 可视化

  • 使用matplotlib模块可视化部分训练数据。

6. 构建生成器

  • 生成器基于ResNet模型结构,根据输入图片的大小采用不同数量的残差块。

7. 构建判别器

  • 判别器是一个二分类网络,使用PatchGANs模型,通过一系列卷积层和激活函数来区分图像真伪。

8. 优化器和损失函数

  • 为生成器和判别器设置不同的优化器,定义了对抗损失和循环一致性损失。

9. 前向计算

  • 定义了生成器和判别器的前向计算过程,包括生成假图像和计算损失。

10. 计算梯度和反向传播

  • 使用MindSpore的value_and_grad函数计算梯度,并进行反向传播更新参数。

11. 模型训练

  • 训练过程包括训练判别器和生成器,记录损失并定期保存检查点。

12. 模型推理

  • 加载训练好的生成器网络模型参数,对新图像进行风格迁移,并展示原始图像和生成图像。
Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐