1.模型简介

生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来

复杂分布上无监督学习最具前景的方法之一。GAN由Ian J. Goodfellow于2014年发明,并在论文

Generative Adversarial Netshttps://proceedings.neurips.cc/paper_files/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf

中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器

(Discriminative Model):

生成器的任务是生成看起来像训练图像的“假”图像

判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同

时训练两个模型——捕捉数据分布的生成模型  𝐺 和估计样本是否来自训练数据的判别模型  𝐷。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会

逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致

时,判别器拥有50%的真假判断置信度。

用  𝑥代表图像数据,用  𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,

𝐷(𝑥)需要处理作为二进制文件的大小为  1×28×28的图像数据。当  𝑥来自训练数据时, 𝐷(𝑥)数值应

该趋近于  1;而当  𝑥来自生成器时, 𝐷(𝑥)数值应该趋近于  0。因此  𝐷(𝑥)也可以被认为是传统的二

分类器。

用  𝑧代表标准正态分布中提取出的隐码(隐向量),用  𝐺(𝑧):表示将隐码(隐向量)  𝑧映射到数据空间

的生成器函数。函数  𝐺(𝑧)的目标是将服从高斯分布的随机噪声  𝑧通过生成网络变换为近似于真实

分布  𝑝𝑑𝑎𝑡𝑎(𝑥)的数据分布,我们希望找到  θ使得  𝑝𝐺(𝑥;𝜃)和  𝑝𝑑𝑎𝑡𝑎(𝑥)尽可能的接近,其中  𝜃代表

网络参数。

𝐷(𝐺(𝑧)) 表示生成器  𝐺生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所

述, 𝐷和  𝐺在进行一场博弈, 𝐷想要最大程度的正确分类真图像与假图像,也就是参数  log𝐷(𝑥);

而  𝐺 试图欺骗  𝐷来最小化假图像被识别到的概率,也就是参数  log(1−𝐷(𝐺(𝑧))) 。因此GAN的损失

函数为:

从理论上讲,此博弈游戏的平衡点是 𝑝𝐺(𝑥;𝜃)=𝑝𝑑𝑎𝑡𝑎(𝑥),此时判别器会随机猜测输入是真图像还是

假图像。下面简要说明生成器和判别器的博弈过程:

- 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。

- 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近

生成器生成出来数据分布的数据判定为0。

- 生成器通过优化,生成出更加贴近真实数据分布的数据。

- 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。

在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数

据分布, 𝑧表示隐码, 𝑥表示生成的虚假图像  𝐺(𝑧)。该图片来源于Generative Adversarial Nets。

详细的训练方法介绍见原论文。

2.数据集

1)数据集简介

MNIST手写数字数据集是NIST数据集的子集,共有70000张手写数字图片,包含60000张训练样本

和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了

尺寸归一化和中心化处理。

本案例将使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图

片。

2)数据集下载

使用download接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要

使用pip install download安装download包。

下载解压后的数据集目录结构如下:

./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test
   ├─ t10k-images-idx3-ubyte
   └─ t10k-labels-idx1-ubyte

数据下载的代码如下:

3)数据加载

使用MindSpore的MnistDatase接口,读取和解析MNIST数据集的源文件构建数据集。然后对

数据进行一些前处理。

4)数据集可视化

通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数

据。

5)隐码构造

为了跟踪生成器的学习进度,在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布

的隐码test_noise输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。

3.模型构建


本案例实现中所搭建的 GAN 模型结构与原论文中提出的 GAN 结构大致相同,但由于所用数据集

MNIST 为单通道小尺寸图片,可识别参数少,便于训练,在判别器和生成器中采用全连接网络架

构和 ReLU 激活函数即可达到令人满意的效果,且省略了原论文中用于减少参数的 Dropout策略和

可学习激活函数 Maxout。

1)生成器

生成器 Generator 的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图

像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense 全连接层来

完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,

使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会

报错。

2)判别器

如前所述,判别器 Discriminator 是一个二分类网络模型,输出判定该图像为真实图的概率。主要

通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回

[0, 1] 的数据范围内,得到最终概率。注意实例化判别器之后需要修改参数的名称,不然静态图模

式下会报错。

3)损失函数和优化器

定义了 Generator 和 Discriminator 后,损失函数使用MindSpore中二进制交叉熵损失函数

BCELoss ;这里生成器和判别器都是使用Adam优化器,但是需要构建两个不同名称的优化器,分

别用于更新两个模型的参数,详情见下文代码。注意优化器的参数名称也需要修改。

4.模型训练

训练分为两个主要部分。

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的

方法,通过提高其随机梯度来更新判别器,最大化  𝑙𝑜𝑔𝐷(𝑥)+𝑙𝑜𝑔(1−𝐷(𝐺(𝑧))的值。

第二部分是训练生成器。如论文所述,最小化  𝑙𝑜𝑔(1−𝐷(𝐺(𝑧)))来训练生成器,以产生更好的虚假图

像。

在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将隐码批量推送到

生成器中,以直观地跟踪生成器 Generator 的训练效果。

5.效果展示

运行下面代码,描绘D和G损失与训练迭代的关系图:

可视化训练过程中通过隐向量生成的图像。

从上面的图像可以看出,随着训练次数的增多,图像质量也越来越好。如果增大训练周期数,当

epoch 达到100以上时,生成的手写数字图片与数据集中的较为相似。下面我们通过加载生成器网

络模型参数文件来生成图像,代码如下:


6.模型推理

下面我们通过加载生成器网络模型参数文件来生成图像,代码如下:

Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐