您的位置:首页 > 产品设计 > UI/UE

Improved Techniques for Training GANs

2017-03-02 20:14 429 查看

Improved Techniques for Training GANs

paper

code

Introduce

对抗网络主要有两个应用:半监督学习和生成视觉相似图片。对抗网络的目的要训练生成网络G(z;θG),输入噪声z生成x=G(z;θG),x为一幅图片,并且x满足真实的数据分布pdata(x)。判别网络D(x)判断生成的图片真假。

对抗网络的双方都想要最小化自己的损失函数。假设JD(θ(D),θ(G))是判别网络的损失函数,JG(θ(D),θ(G))是生成网络的损失函数。纳什均衡点就是参数空间的点(θ(D),θ(G)),JD对于θ(D)取得最小值,JG对于θ(G)取得最小值。对抗网络中,θ(D)的更新减少了J(D),但是同时又增加了J(G),θ(G)的更新减少了J(G),同样又增加了J(D)。例如下面的例子,一个网络想要通过修改x来最小化xy,另一个网络想要通过修改y来最小化−xy,使用梯度下降的方法会进入一个稳定的轨道中,并不会收敛到(0,0)点。

对抗网络的目的需要在高维非凸的参数空间中,找到一个纳什均衡。但是GAN网络使用梯度下降的方法只会找到低的损失,不能找到真正的纳什均衡。本论文中,作者通过引入了一些方法,提高网络的收敛。

Toward Convergent GAN Training

Feature Matching

原始的GAN网络的目标函数需要最大化判别网络的输出。作者提出了新的目标函数,motivation就是让生成网络产生的图片,经过判别网络后的中间层的feature 和真实图片经过判别网络的feature尽可能相同。假定f(x)为判别网络中间层输出的feature map。生成网络的目标函数定义如下:

||Ex∼pdataf(x)−Ez∼pz(z)f(G(z))||22

判别网络按照原来的方式训练。相比原先的方式,生成网络G产生的数据更符合数据的真实分布。作者虽然不保证能够收敛到纳什均衡点,但是在传统GAN不能稳定收敛的情况下,新的目标函数仍然有效。个人觉得,判别网络从输入到输出逐层卷积,pooling,图片信息逐渐损失,因此中间层能够比输出层得到更好的原始图片的分布信息,拿中间层的feature作为目标函数比输出层的结果,能够生成图片信息更多。可能采用这种目标函数,生成的图片会效果会更好。

MiniBatch discrimination

判别网络如果每次只看单张图片,如果判断为真的话,那么生成网络就会认为这里一个优化的目标,导致生成网络会快速收敛到当前点。作者使用了minibatch的方法,每次判别网络输入一批数据进行判断。假设f(x)∈RA表示判别网络中间层的输出向量。作者将f(x)乘以矩阵T∈RA×B×C,得到一个矩阵Mi∈RB×C。计算矩阵Mi每行的L-1距离,得到cb(xi,xj)=exp(−||Mi,b−Mj,b||L1∈R 。

定义输入xi的输出o(xi)如下:

o(xi)b=∑j=1ncb(xi,xj)∈R o(xi)=[o(xi)1,o(xi)2,o(xi)3,......o(xi)n]o(X)∈Rn×B

将o(xi)作为输入,进入判别网络下一层的输入。

Historical averaging

在生成网络和判别网络的损失函数中添加一个项:

||θ−1t∑i=1tθ[i]||2

公式中θ[i]表示在i时刻的参数。这个项在网络训练过程中,也会更新。加入这个项后,梯度就不容易进入稳定的轨道,能够继续向均衡点更新。

One-side label smooth

将正例label乘以α,, 负例label乘以β,最优的判别函数分类器变为:

D(x)=αpdata(x)+βpmodel(x)pdata(x)+pmodel(x)

本文中作者将正例乘以α, 负例乘0。这里我也没看明白,如果后面明白以后, 持续更新……。

Virtual batch normalization

BN使用能够提高网络的收敛,但是BN带来了一个问题,就是layer的输出和本次batch内的其他输入相关。为了避免这个问题,作者提出了一种新的bn方法,叫做virtual batch normalization。首先从训练集中拿出一个batch在训练开始前固定起来,算出这个特定batch的均值和方差,进行更新训练中的其他batch。VBN的缺点也显而易见,就是需要更新两份参数,比较耗时。

Semi-supervised learning

标准的分类网络将数据x输出为可能的K个classes,然后对K维的向量使用softmax:pmodel(y=j|x)=exp(lj)∑kk=1exp(lk)。标准的分类是有监督的学习,模型通过最小化交叉熵损失,获得最优的网络参数。

对于GAN网络,可以把生成网络的输出作为第K+1类,相应的判别网络变为K+1类的分类问题。用Pmodel(y=K+1|x)表示生成网络的图片为假,用来代替GAN的1−D(x)。对分类网络,只需要知道某一张图片属于哪一类,不用明确知道这个类是什么,通过pmodel(y∈1,2,...,k|x)就可以训练。

所以损失函数就变为了:

L=−Ex,y∼pdata(x,y)[logpmodel(y|x)]−Ex∼G[logpmodel(y=K+1|x)]=Lsupervised+Lunsupervised,

Lsupervised=−Ex,y∼pdata(x,y)[logpmodel(y|x)]

Lunsupervised=−Ex∼G[logpmodel(y=K+1|x)]

如果把D(x)=1−pmodel(y=K+1|x),上述无监督的表达式就是GAN的形式:

Lunsupervised=−Ex∼pdata(x)logD(x)+Ez∼noiselog(1−D(G(Z)))

Experiment

作者在mnist, cifar10, svhn数据集上做了实验,在这里只贴了cifar10的实验结果了。



Conclusion

表示大牛写的文章,高度太高,很难理解,博客中有错误的地方,希望大家能多多指教,共同讨论。

Reference

纳什均衡 :https://en.wikipedia.org/wiki/Nash_equilibrium

Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen Improved Techniques for Training GANs
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  gan 深度学习