APP下载

基于Wasserstein距离的双向学习推理

2020-07-23花强刘轶功张峰董春茹

关键词:散度相似性度量

花强,刘轶功,张峰,董春茹

(河北大学 河北省机器学习与计算智能重点实验室, 河北 保定 071002)

近年来,基于无监督学习的深度学习技术受到了越来越多学者的关注[1-4],其中生成对抗网络(generative adversarial networks,GAN)[5-9]和变分自编码器(variational auto-encoder,VAE)[10]是高维复杂数据建模的2类最重要的深度生成模型. GAN模型的优势在于不需要对生成分布进行显式表达,避免了VAE模型中复杂的马尔可夫链采样和推理计算,可以训练得到生成高品质样本的生成模型.然而,原始GAN及其众多变种不包含从数据空间到隐变量空间的映射,缺少有效的推理机制,并且缺乏完备的理论保障,从而使得GAN的训练需要谨慎地选择各个超参数.进一步的研究发现,GAN对生成样本多样性和准确性的惩罚不平衡,有可能导致生成器倾向于重复生成少数几种样本,出现模式坍塌(mode collapse)[11-12]问题.因此,Donahue等[13]受VAE的编码器原理启发,提出了一种双向生成对抗网络(adversarial feature learning,BiGAN).BiGAN在原始GAN框架的基础上,引入编码器实现了隐变量特征空间的学习和推理机制.同一时期,Dumoulin等[14]独立的提出了与BiGAN类似的模型——对抗学习推理(adversarially learned inference,ALI),ALI将一个称为推理器的编码网络和一个深度定向生成网络集成在GAN的框架下共同训练.模型结合了GAN和VAE的部分优点,具有良好的学习性能,与当前流行的自监督和弱监督特征学习方法[15-16]相比也具有一定的竞争力.由于BiGAN和ALI模型在优化目标函数时,需要最小化真实数据分布和生成样本分布的差异,通常利用分布散度作为基本度量,如Jensen-Shannon(JS)散度或Kullback-Leibler(KL)散度.然而,真实数据分布和生成样本分布的支撑集是高维空间中的低维流形时,2个分布重叠部分的测度为0,这将导致在训练模型迭代过程中,出现梯度为零或者无穷大的情况,从而使得生成器无法接收有效的梯度信息,导致训练失败,影响模型鲁棒性.基于此原因,Arjovsky等[17]提出了Wasserstein GAN(WGAN)模型.WGAN使用Wasserstein距离代替GAN模型损失函数中的KL散度作为衡量2个概率分布之间的相似性.理论和实验表明,该方法能够在一定程度上缓解GAN模型训练过程中的梯度消失和爆炸现象,但是WGAN不具备隐含特征学习功能,容易出现模型坍塌的问题.

为了同时解决上述生成模型存在的模式坍塌和梯度消失爆炸问题,本文提出了一种基于Wasserstein距离的双向学习推理模型(Wasserstein bidirectional learned inference,WBLI).WBLI使用Wasserstein距离代替BiGAN中的KL散度作为衡量概率分布差异的度量;同时,WBLI模型由生成器、编码器和判别器3个网络模块构成,其中生成器和编码器在数据特征空间和与之对应的隐变量空间的联合分布之间建立双向联系,而判别器度量了2个联合分布的Wasserstein距离;最后,WBLI采用了交替迭代算法对网络参数进行训练.在MNIST和Fashion MNIST数据集上的实验结果表明,WBLI模型可以有效缓解基于KL散度的模型在训练过程中梯度消失或梯度爆炸的缺陷;此外,WBLI通过引入有学习数据样本内在特征的逆映射的编码器结构,具有类似于BiGAN和ALI模型具有的隐式正则化、模式覆盖等优点.

本文的主要贡献:1)提出了一种基于Wasserstein距离的双向学习推理模型(WBLI),该模型缓解了基于KL散度的BiGAN及ALI训练过程中梯度消失或梯度爆炸问题,从而提高了模型对于样本分布的鲁棒性;同时WBLI一定程度可以缓解WGAN中的模式坍塌问题.2)从结构及实验2方面将WBLI与BiGAN和WGAN进行了深度比较,结果表明WBLI从模型功能和图像生成效果上都有一定的提高.

1 相关模型

首先介绍双向生产对抗网络模型及Wasserstein距离,并引入本文工作所需的主要数学符号和基础概念.

1.1 双向生成对抗网络

GAN最早于2014年由Goodfellow等[5]提出,是一种实现复杂数据分布学习的无监督生成模型.该模型主要由生成器网络G和判别器网络D两部分构成,其中生成器将输入的随机噪声映射为生成样本,而判别器同时接收真实样本和生成样本,并判别输入样本的真伪(即判别样本是真实样本还是生成样本).在GAN模型的训练过程中,通过构建目标函数引入竞争机制让这2个网络同时得到优化,最终使得生成器生成与真实样本数据分布足够相似的新数据分布.GAN模型的结构如图1所示.

图1 GAN 结构Fig.1 Structure of GAN

设q(x)为真实数据分布,其中x∈ΩX,设p(z)为一个固定的隐编码分布,其中z∈ΩZ,通常定义为简单分布,例如标准正态分布p(z)=N(0,1),生成器G∶ΩZ→ΩX:可以将隐编码分布映射到数据分布,D(x)代表x来自于真实数据分布q(x)而不是生成样本分布的概率.据此GAN网络的优化目标函数如下:

(1)

图2 BiGAN结构Fig.2 Structure of BiGAN

(2)

BiGAN采用与GAN相同的基于交替梯度的EM优化算法来优化目标函数[18].理论上,在BiGAN达到最优解时,即KL散度收敛达到最小的情况下,可认为所有边缘分布和所有条件分布都已达到匹配.然而,如引言部分所述,BiGAN目标函数中采用KL散度衡量数据分布间的差异,在某些情况下会出现梯度爆炸的情形,从而导致训练失败,影响模型鲁棒性[19].

1.2 Wasserstein距离

生成模型的传统设计方法依靠最大似然估计,或者最小化未知的真实数据分布q(x)和生成样本分布pG(x)之间的KL散度

(3)

文献[19]中证明当处理2个由低维流形支持的分布时,那么这2个低维流形将会具有极小重叠甚至没有重叠,这意味着KL散度在大部分区域是无意义的,即KL(q(x)‖pG(x))=∞,并且JS散度将变为常数log2,这将导致判别器损失函数的梯度为无穷或零,从而导致模型训练失败.因此,Arjovsky等[17]通过全面的理论分析,把Wasserstein距离与其他广受关注的度量概率分布的距离和散度相比,用Wasserstein距离替换原始GAN中的KL散度,提出了WGAN模型,其采用的Wasserstein距离定义为

(4)

其中,∏(pr,pg)是以pr和pg为边缘分布的所有可能的联合概率分布的集合.对于每个联合分布γ(x,y),都可以通过采样的方法获得(x,y)~γ.计算(x,y)的范数‖x-y‖,这样就可以计算每个联合分布γ(x,y)的期望值E(x,y)~γ[‖x-y‖].W(pr,pg)为γ(x,y)期望的下确界,更直观的说,它表示为了将pr移动到pg需要将x移动到y的最小距离或能量.

Wasserstein距离相对KL散度与JS散度具有相对平滑特性,即使2个分布之间没有交集,Wasserstein距离亦能够正确度量它们之间的差异,进而产生有意义的梯度.因此,WGAN能有效缓解基于KL散度或JS散度的GAN模型的梯度消失或梯度爆炸问题.然而,理论上目前的深度神经网络只能够逼近连续映射,而GAN训练过程中,目标映射是具有间断点的非连续映射,不在深度神经网络的可表示泛函空间之中,这导致了收敛困难,从而产生了模式坍塌,可见WGAN并没有完全克服GAN中的模式坍塌问题.

2 基于Wasserstein距离的双向学习推理

基于以上分析,本文将Wasserstein距离引入到BiGAN中,提出了一种基于Wasserstein距离的双向学习推理模型(即WBLI),以综合BiGAN和Wasserstein距离的优点,从而获得更加稳定的学习模型.

(5)

其中,θG,θE,θD分别表示生成器G、推理器E和判别器D的模型参数.WBLI使用对抗方式联合训练一个生成器与一个推理器,生成器G将服从简单分布的隐变量映射到数据空间,而推理器E将训练样本从数据空间反映射回隐变量空间.因此,对抗博弈在G,E与D之间展开.因为WBLI具有逆映射结构,所以推理器在编码的过程中将相似样本的隐变量聚在一起,使得流形连续,达到隐式正则化的效果,从而可以提高模型泛化能力.

γ([x1,x2],[y1,y2])∈∏(q(x,z),p(x,z)),

此时Wasserstein距离的计算公式如(6)式

(6)

(7)

其中

由于映射函数f的参数可调节,故f要满足在γ∉π时使整体附加项趋于无穷,使得supf无解,从而达到类似(6)式的约束效果.而当γ∈π时,s和x都是从同一个边缘分布中采样,s和x两个随机变量分布的期望值相等,故Ef(s)-Ef(x)=0,t和y同理,从而整体附加项等于零.这样就成功地去掉了γ∈π的约束.将(7)式2项合并得到(8)式

(8)

由于sup为凸函数且inf为凹函数,根据极小极大原理[20],得到 (9) 式

(9)

(10)

在具体算法实现中,可将函数f(x)用一簇参数为w的神经网络参数化,并采用权重裁剪方法使得函数满足Lipschitz连续,此时求上界的问题便可转化为如下(11)式所表示的优化问题

(11)

接下来生成器要近似地最小化Wasserstein距离,可由此设置最小化损失函数L,由于Wasserstein距离的优良性质,可有效避免判别器的梯度消失问题.再考虑到L的第1项与生成器无关,得到WBLI最终的2个损失函数

LG=-Ex~p(z)[fw(x)],

(12)

(13)

LD是(11)式的相反数,可以指示训练进程,其数值越小,表示真实分布与生成分布的Wasserstein距离越小,WBLI训练得越好.在训练生成器时,首先固定判别器的参数,从正态分布N(0,1)中采样m个样本作为一个批次的训练数据输入判别器,然后根据(13)式计算生成器损失,同样采用RMSProp[21]算法更新其参数.由于更优的判别器可以反向传播给生成器更准确的梯度信息,因此从训练开始,在每次更新生成器之前,均需更新判别器n次,以使判别器D更快收敛.完整的训练过程如算法1所示.

算法1 WBLI输入 z: 隐变量;T: 训练数据集;m: 批次大小;α: 学习率;c: 判别器的梯度剪裁数.n: 生成器和推理器优化过程中的判别器更新次数.输出 判别器参数θD;生成器参数θG;推理器参数θE1)随机初始化θD和θG2)重复3) for t=0,…,n do 4) z1,…,zm~p(z)5) x1,…,xm~q(x)6) z^1,…,z^m~qE(z^|x=xi),i=1,…,m7) x^1,…,x^m~pG(x^|z=zj),j=1,…,m8) LD←1m∑mi=1fθD(D(xi,z^i))-1m∑mi=1fθD(D(x^i,zi))9) θD←θD+α×RMSProp(θD,∇θDLD)10)剪裁θD,将其限制在[-c,c]范围内11)结束重复12)z1,…,zi~p(z)13) LG←1m∑mi=1fθD(D(xi,z^i))14) θG←θG+α×RMSProp(θG,∇θGLG),θE←θE+α×RMSProp(θE,∇θELG)15) 结束至判别器收敛

3 实验结果与分析

3.1 实验设置

本文实验运行操作系统为64位Windows10,编程语言为Python3.6,基于TensorFlow开源框架,使用的数据集为MNIST[22]和FashionMNIST[23].MNIST是手写体数据集,分为60 000个训练样本和10 000个测试样本.所有数字图像都标准化为28X28像素的固定大小.每个像素由0到255之间的值表示,其中0为黑色,255为白色,介于两者之间的数值代表不同的灰色影像.FashionMNIST由德国Zalando科技公司的研究部门提供,FashionMNIST的大小、格式和训练集、测试集划分与原始的MNIST完全一致,涵盖了来自10种类别的共70 000个不同商品的正面图片.

3.2 生成样本质量与Wasserstein距离

众所周知,在训练GAN和BiGAN过程中,生成器G的目标是尽量生成真实的图片去欺骗判别器D;而D的目标是尽量把G生成的图片和真实的图片区分开.这样,G和D就构成了一个动态的“博弈过程”.然而GAN和BiGAN在训练过程中没有任何指示训练进度的指标,只能基于经验和生成样本的效果来判断模型是否收敛.本文提出的WBLI模型在引入Wasserstein距离度量后,自动建立了生成模型训练的进程监视指标.本节实验就是验证生成样本质量同Wasserstein距离的正比关系.

在本节实验,WBLI模型中生成器G、推理器E、判别器D均由3层神经元网络实现.设定3个网络的隐含层神经元个数均取128,则生成网络G输入、隐含和输出层神经元个数依次为10-128-784,推理网络E与G的神经元结构镜像对称为784-128-10,判别网络D接收G和E的联合数据进行真伪判断,因此网络结构为794-128-1.模型学习率设为0.000 1.根据算法1,每训练5次判别网络,则更新训练1次生成网络和推理网络,输出损失函数并记录生成样本.抽取了5张不同迭代阶段的生成样本和对应Wasserstein距离值,关系如图3所示,可以直观地看到,判别器所输出的Wasserstein距离与生成器的生成图片的质量高度相关(更多随迭代次数生成的样本序列见图4).随着Wasserstein距离的不断减小,生成样本的质量逐渐提高.因此,Wasserstein距离可作为训练阶段进程评判指标,这也是WBLI相比于其他基于KL散度模型的一个优势.

图3 生成样本质量和Wasserstein距离的关系Fig.3 Relationship between the quality of generated images and Wasserstein distance

图4 生成样本序列Fig.4 Generating sample sequence graphs

3.3 WBLI模型生成样本多样性测试

如前所述,GAN和WGAN不具备隐含特征学习功能,易出现模式坍塌问题,即无论输入模型的简单随机分布如何变化,生成器生成的样本都缺乏多样性,不能支撑数据空间.WBLI模型能否生成多样性样本,是评估该模型优劣的重要指标.

在计算机图像领域,定量评价图像相似度的方法有很多,例如常用方法之一是基于单尺度的结构相似性指标SSIM[24-25]

SSIM(x,y)=[l(x,y)]α·[c(x,y)]β·[x(x,y)]γ,

(14)

其中x和y分别表示2个视窗图像,l(x,y),c(x,y)和s(x,y)分别表示2个视窗图像的亮度、对比度和结构相似度度量,其具体计算公式见文献[24-25];而参数α,β,γ用于控制3部分度量在SSIM中所占比例,一般设α=β=γ=1.当计算2张影像的结构相似性值时,会创建一个局部性视窗,并按式(14)计算视窗内图像的结构相似性值,每次以像素为单位移动视窗,直到整张影像每个位置的SSIM都计算完毕并求取均值,作为2张影像的结构相似性指标.而为构建更贴近主观的图像相似性质量评价方法,文献[26]在SSIM的基础上提出了多层级结构相似性指标MS-SSIM[26],其基本思路为同时考虑多个尺度对图像相似度进行度量,具体定义为

(15)

其中M为尺度层数,lM(x,y)为视窗图像x和y在M层上的亮度相似性度量,而cj(x,y)和sj(x,y)分别为视窗图像x和y在第j层尺度上的对比度和结构相似性度量;而参数αM用于控制第M层亮度相似性所占比例,βj和γj为第j层尺度上对比度相似性和结构相似性度量所占的比例.本文采用多层级结构相似性指标MS-SSIM衡量图片集的相似性,其值越小,代表图像集多样性越好.

图5为FashionMNIST训练集和测试集各10类数据的MS-SSIM平均值,可看到大部分类的MS-SSIM值都小于0.25,因此选择0.25作为判断生成样本是否达到真实数据集多样性标准的阈值.同时,通过统计生成样本MS-SSIM值的变化也可监控模式坍塌情况.

图5 数据集的MS-SSIM平均值Fig.5 MS-SSIM average graphs of data sets

图6表明随着训练迭代增加,生成样本的MS-SSIM值逐步减少,最后降低至FashionMNIST数据集的MS-SSIM平均值之下.由此可见,WBLI模型的训练过程稳定,没有产生模式坍塌现象.作为对照,图7给出了BiGAN训练过程中MS-SSIM值的变化趋势.由图7可见,在BiGAN训练过程中生成的样本MS-SSIM值停留在较高水平,并随迭代次数增加有增加趋势,这表明生成样本出现了相似度过高现象,样本趋于单调,从而发生模式坍塌.

图7 BiGAN模式坍塌下MS-SSIM值Fig.7 Mode collapse of BiGAN

图6 WBLI生成样本的MS-SSIM值Fig.6 MS-SSIM graphs for generating samples

为更加直观地体现WBLI生成样本的多样性,图8给出了部分WBLI生成样本的示例.

图8 WBLI生成样本示例Fig.8 Generated samples by WBLI

3.4 生成样本分类识别率

评估生成样本的质量,只是通过图像观察是不规范的.基于同一Le Net-5[25]卷积分类器,分别对WBLI、BiGAN和WGAN模型生成的样本数据在给定类别信息后,进行分类测试.

首先为了对BiGAN,WGAN和WBLI有更全面的分析,表1中对这3种模型从是否包含隐变量编码、梯度稳定性以及判别器输入变量构成3个方面进行了比较.由表1可看出,BiGAN模型可以通过编码器从数据空间学习到隐含特征,并且判别器的输入是包含原始数据和隐变量数据的高维向量,使得网络能够对低层信息x和高层信息z共同进行判别,从而提高了判别能力.WGAN凭借Wasserstein距离的优点和权重裁剪技术避免梯度消失,且其值保持平滑稳定.本文提出的WBLI模型正是同时集成了2类模型的优势.

表1 3种模型功能的比较

分类识别率实验以MNIST手写数字数据集为初始训练集.实验设定3种模型分别以MNIST手写数据集进行训练.将训练好的3种模型分别生成1 000个生成样本,然后输入到同样已经预先训练好的Le Net-5卷积分类器中,3种模型生成样本的分类测试正确率如表2所示,同时表2也列出部分生成样本图像示例.从生成的示例样本可看到,WBLI模型生成的样本更清晰,各类数字间特征更明显,辨识度更高.而识别正确率结果也说明WBLI模型生成的样本更具有真实样本的特征,而这正是因为WBLI综合了WGAN的梯度稳定性与BiGAN的推理结构.

表2 3种生成样本分类识别率

4 总结

在BiGAN模型结构基础上,引入Wasserstein距离代替KL散度用于计算分布间的差异性,建立了基于Wasserstein距离的鲁棒无监督生成式学习模型WBLI,并将基于Wasserstein距离的多维优化问题转化为可求解形式,得到了模型的生成器和判别器的对抗损失函数.一方面,由于WBLI采用的Wasserstein距离具有整体平滑的特性,理论上解决了当前基于KL散度或JS散度的无监督生成模型(如GAN,BiGAN)的梯度爆炸或梯度消失问题;另一方面,借鉴BiGAN中引入推理器E从而使得模型可以有效缓解模式坍塌问题.WBLI解决了原模型训练不稳定的问题,建立了一个可靠的与生成样本的质量高度相关的训练进程指标,实验结果验证了上述优点.

猜你喜欢

散度相似性度量
一类上三角算子矩阵的相似性与酉相似性
鲍文慧《度量空间之一》
定常Navier-Stokes方程的三个梯度-散度稳定化Taylor-Hood有限元
浅析当代中西方绘画的相似性
代数群上由模糊(拟)伪度量诱导的拓扑
突出知识本质 关注知识结构提升思维能力
度 量
基于f-散度的复杂系统涌现度量方法
基于隐喻相似性研究[血]的惯用句
静电场边界问题专题教学方法探索及推广应用