基于隐变量后验生成对抗网络的不平衡学习
2021-06-01何新林戚宗锋李建勋
何新林, 戚宗锋, 李建勋
(1. 上海交通大学 电子信息与电气工程学院, 上海 200240;2. 电子信息系统复杂电磁环境效应国家重点实验室, 河南 洛阳 471003)
数据分布不平衡是影响数据挖掘或者机器学习分类算法性能的一个关键因素.传统的分类算法假设数据分布是均衡的,以提升全局准确率为优化目标.但现实世界中的许多应用如医疗诊断[1]、金融欺诈检测[2]、人脸识别[3]及机械故障检测[4]等均存在数据分布不均衡的情况,某些类别的数目远远超过其他类别,二分类问题中通常分别称为多数类和少数类.在这些问题中,少数类是人们关注的重点,传统分类算法致力于提升全局准确率,导致分类结果偏向多数类而忽视了人们关注的重点.
目前主要从数据层面[5-9]和算法层面[10-11]或者两者结合[12-13]来解决不平衡分类问题.数据层面的解决办法通过对少数类过采样或者多数类欠采样使数据达到均衡,算法层面的解决办法通过设定算法参数如数据加权、代价敏感来强调少数类,两个方法的结合通过采样来平衡数据集,同时改进算法来强调少数类.
尽管现有的解决办法达到了良好的表现[10],鉴于最近几年深度网络生成模型在表示学习上显现出的巨大优势[14-15],本文关注利用深度神经网络对少数类进行过采样,因为过采样不会丢失数据中重要的信息,而且可以作为预处理步骤来进行可视化或者与算法层面方法相结合.传统的采样方法都是基于线性插值的方式,不能根据数据的概率分布函数进行采样.用过采样方法来解决不平衡分类问题是通过生成少数类样本来使数据达到均衡.最简单的方法是复制现有的少数类样本,这种方法容易导致过拟合.Chawla等[5]提出在选定的少数类样本和它们的K近邻之间进行线性插值来生成少数类样本,这种方法把所有少数类样本等同看待,没有考虑数据内分布的差异性,容易导致生成样本落入多数类区域.Han等[6]提出识别出位于类间边界的难以学习的少数类样本,对每个边界集合样本生成同等数量的样本.He等[7]用自适应方法根据每个少数类样本K近邻中多数类样本的数量来决定对每个少数类样本生成样本的数目,这种方法容易受噪声影响,对落入多数类区域的噪声给予过多关注.Barua等[8]提出识别出那些难以学习的少数类样本,并基于其与多数类样本的欧式距离给每个少数类样本分配权重,再用层次聚类法把少数类样本分为若干簇,在簇内根据权重采样插值生成少数类样本.Douzas等[9]提出利用条件生成对抗网络学习数据的多类分布,再进行少数类过采样.
针对现有的基于插值的过采样算法仅仅利用邻域样本的缺点,本文引入了隐变量模型,提出了一种基于隐变量后验生成对抗网络的过采样(LGOS)算法.生成对抗网络利用了所有少数类样本来学习数据真实概率分布,在隐变量后验上采样克服了基于高斯噪声生成对抗网络生成数据的随机性.同时本文引入了权重缩放因子,提出了与过采样算法相结合的不平衡分类算法TrWSBoost,人工合成的过采样样本和原始样本有很大相关性就相当于迁移源领域样本,原始样本被当作目标领域样本来迭代训练集成分类器.
1 生成对抗网络
生成对抗网络(Generative Adversarial Networks, GANs)是Goodfellow等[16]于2014年提出的一种无监督生成模型.生成对抗网络由两部分组成:生成器G和判别器D,结构如图1所示.这两个网络以对抗方式进行训练,生成器用来学习真实数据的分布,输入是隐变量先验,通常假设为高斯噪声或者均匀噪声,输出为接近真实数据分布的生成数据.判别器是一个二分类器,用来判别输入是真实数据还是生成数据的概率.
生成器G和判别器D的训练目标是相互对抗的.判别器对输入样本进行真假判定,通过训练不断提升自己的分类效果,识别出生成器所生成的样本.生成器希望生成更加真实的样本以混淆判别器,让判别器无法分辨真假.设输入的随机噪声为z,生成器G将随机噪声转换为生成样本G(z).判别器D对输入样本输出D(x)为[0,1]范围内的一个实数,表示输入样本为真实样本的概率值.其损失函数为
Ez~Pz{log[1-D(G(z))]}
(1)
式中:x为真实输入样本;Pr为真实数据分布;Pz为输入噪声分布;E为数据期望.
图1 生成对抗网络结构图Fig.1 Framework of generative adversarial network
两个网络进行迭代训练,理论上最终达到纳什均衡时,生成器G生成的数据分布和真实数据分布相同,判别器D输出概率值为0.5,无法区分真实样本和生成样本.
2 基于隐变量后验的生成对抗网络模型
2.1 隐变量模型
在生成对抗网络中,把高斯噪声或者均匀噪声当作隐变量先验分布,而隐变量真实先验分布和真实后验分布未知,所以生成数据质量具有随机性.变分自编码隐变量模型用近似后验分布代替真实先验分布,运用变分贝叶斯方法,在概率图模型上执行高效的近似推理和学习.均值场方法在很多情况下难以求得后验分布的解析解,变分自编码隐变量模型在概率图框架下形式化这个问题,通过优化对数似然的下界来间接优化最大对数似然.
近似后验分布和真实后验分布的距离用KL散度度量:
DKL(qφ(z|x)‖pθ(z|x))=
Ez~qφ[logqφ(z|x)-logpθ(z|x)]
(2)
式中:φ为变分模型参数;qφ(z|x)为近似后验分布,假设其服从高斯分布;θ为数据生成模型参数;pθ(z|x)为真实后验分布.
通过贝叶斯变换可得变分下界为
L(θ,φ;x)=
logp(x)-DKL[qφ(z|x)‖pθ(z|x)]=
Ez~qφ[logpθ(x|z)]-DKL[qφ(z|x)‖pθ(z)]
(3)
隐变量模型通过编码器得到近似后验分布qφ(z|x)的均值和协方差,在隐空间采样输入解码器重构原始输入数据,误差沿网络反向传播更新网络参数来逼近变分下界.
2.2 基于隐变量后验分布的生成对抗网络模型建立
本文所建立模型中,编码器E从真实样本提取隐变量作为监督信号,在隐空间采样作为信号输入生成器G用来生成和真实数据同分布的样本.隐变量模型的解码器从隐空间采样重构原始输入样本,故可把生成器和解码器结合,提出了一种数据生成模型LGOS.
编码器E输入x,输出为隐变量分布均值和方差,可表示为
z~E(x)=qφ(z|x)
(4)
式中:qφ(z|x)~N(μ,σ),μ、σ分别为隐变量对应高斯分布的均值和方差.
在隐空间采样输入生成器G得到生成数据:
(5)
假设隐变量先验分布为正态分布N(0,I),则变分下界为
L(θ,φ;x)=-Lele-LKL
(6)
式中:Lele为重构误差;LKL为隐变量近似后验分布和先验分布之间的KL散度,具体表示如下.
Lele=-Eqφ(z|x)[logpθ(x|z)]=
(7)
LKL=DKL(qφ(z|x)‖pθ(z))=
(8)
式中:J为隐变量维度;μj和σj分别为样本近似后验分布对应的均值和方差.
判别器D损失为
LD=-Ex~pr{log[D(x)]}-
Ez~qφ(z|x){log[1-D(G(z))]}
(9)
判别器D对真实输入样本输出较大的似然概率值,而对生成器G生成的样本输出小的似然概率值.
生成器G对抗损失为
LDG=-Ez~pz{log[D(G(z))]}
(10)
生成器G和判别器D进行反向迭代,两个模型一直处于对抗训练过程.
2.3 分布自适应
用真实数据和生成数据之间的欧式距离来度量似然函数在很多情况下不适用.因为真实数据和生成数据要服从同分布,在模型中添加边缘分布自适应和条件分布自适应两个限制条件.
边缘分布的差距用最大化均值差异(MMD)度量[17],最大化均值差异把原变量映射到再生希尔伯特空间,在另一空间中求取两个分布的距离.在生成对抗网络中,判别器的目的就是学习数据样本的特征来进行区分,所以在LGOS模型中,用生成对抗网络判别器的最后一个隐层作为特征空间,特征向量的欧式距离即为MMD距离.
(11)
式中:l为判别器最后1个隐层;f为输入数据在第l层对应的特征提取函数.
在LGOS模型中,用一个分类器C获得条件概率,分类器输出激活函数为softmax,输出向量各维度表示样本属于各个类别的概率.条件分布距离损失为
(12)
分类器用原始数据训练,用交叉熵函数作为其损失函数,分类器损失为
(13)
各模块最终损失为
LE=LKL+γ1Lelement+γ2LMMD+γ3LGC
(14)
LG=LDG+γ1Lelement+γ2LMMD+γ3LGC
(15)
LD=-Ex~pr{log[D(x)]}-
Ez~qφ(z|x){log[1-D(G(z))]}
(16)
(17)
式中:γ1、γ2及γ3为超参数,用于调节各部分损失比重大小.
网络结构图如图2所示.
图2 本文LGOS算法网络结构图Fig.2 Framework of proposed LGOS algorithm
2.4 权重缩放的迁移学习模型
以TrAdaboost[20]为基础,提出了改进的带权重缩放因子的TrWSBoost迁移学习分类算法.把生成的少数类样本当作源领域样本,原始训练数据当作目标领域样本,目标是要训练迁移学习集成分类器.
在TrWSBoost模型中,在每一轮迭代时,对于源领域样本,被基学习器错分时,认为这些错分样本是与原始样本不同分布的样本,错分样本权重在下一轮迭代时应该降低.正确分类样本权重保持不变.目标领域样本错分时下一轮迭代权重增加,正确分类时权重保持不变.在TrAdaboost算法中,源领域样本错分时权重衰减过快[21-22],且模型融合时仅融合了后一半模型,没有充分利用源领域信息.考虑到本文中源领域样本和目标领域样本较大的相关性,为了解决权重衰减过快的问题,本文以目标领域样本加权错误率和源领域样本加权错误率为基础,设定了权重缩放因子.当目标领域加权错误率低时,认为模型表现良好,减慢源领域样本权重更新速度,反之亦然.
最终算法结构图如图3、4所示,算法流程如下:
(1) LGOS过采样算法.
(a) 初始化.设置训练批次大小为m,初始化编码器E、生成器G、判别器D和分类器C 4个网络参数.设置超参数γ1=0.01,γ2=1,γ3=0.02.
(b) 从真实数据中随机抽取批次大小为m的训练数据x,并输入编码器E后得到隐变量近似后验分布z.
(f) 汇总各网络损失,误差反向传播更新网络参数.
(g) 重复执行步骤(2)~(6),更新网络参数直至收敛.从隐空间采样输入生成器G得到最终生成样本.
(2) TrWSBoost集成分类算法.
(e) 当前步数t自增1,未到N步时重复步骤(9)~(11).
图3 本文LGOS算法流程图Fig.3 Flowchart of proposed LGOS algorithm
图4 训练TrWSBoost集成分类器流程图Fig.4 Flowchart of training of TrWSBoost ensemble classifier
3 实验与分析
通过对比实验分析本文所提出的LGOS算法和TrWSBoost算法的性能,以随机过采样(ROS)和传统的基于插值的过采样算法如SMOTE[5],Borderline-SMOTE[6],ADASYN[7],MWMOTE[8]作为比较对象.首先,比较LGOS算法和其他过采样算法生成数据的分布差异;其次,在加州大学欧文分校6个数据集(UCI)上详尽比较了各过采样算法生成数据训练的分类器的性能;最后,把过采样生成的数据作为迁移学习源领域样本,原始数据作为目标领域样本,对TrWSBoost算法和TrAdaboost算法进行比较.
图5 LGOS 算法和其他过采样算法生成数据图形比较Fig.5 Visual comparison of synthetic data of LGOS and other oversampling methods
3.1 实验参数设置和评估指标
实验选取了6个UCI公开数据集,把某类或几类指定为少数类,其他类作为多数类人为制造不平衡数据集,各数据集描述见表1.He等[10]的实验结果显示当数据分布接近均衡时,分类器表现最好,故在本文中,对每一数据集均通过过采样方法使少数类和多数类均衡.在实验中,随机选取80%的数据为训练集,其余数据作为测试集,取10次实验结果平均值作为报告结果.第一阶段实验采用决策树分类器,采用基尼系数作为节点切分标准,叶子节点最少样本数设置为1.SMOTE、Borderline-SMOTE、ADASYN算法K近邻设置为5,对MWMOTE算法K1设为5,K2设为3,K3设为3,聚类簇合并阈值Cp设为3.第二阶段实验把过采样生成的少数类样本当作源领域样本,原训练集数据当作目标领域样本,利用TrWSBoost算法训练迁移学习分类器.为了与第一阶段结果对比,弱分类器同样选取决策树分类器,迭代步数设置为50,当分类精确度很高时提前终止以防止过拟合.
表1 各数据集特性描述Tab.1 Description of characteristics of datasets
数据分布不均衡时,全局准确率不能作为分类器评价指标[10],实验分别选取Recall、F-measure、G-mean及AUC[10]来评估过采样算法的有效性.Recall衡量分类器在召回单一类别上的能力,在少数类应重点关注的应用中,这个指标是关注的重点,F-measure表示分类器在预测单一类别上的完备性和准确性的均衡,G-mean表示分类器召回两个类别上的综合表现,AUC与数据分布无关,适合比较不同分类器的差异.
3.2 实验结果与分析
选取satimage数据集,用各过采样算法生成同样数量的少数类样本,利用t分布随机领域嵌入(TSNE)投影算法将数据降到两维进行图形可视化表示.
图5所示为LGOS 算法和其他过采样算法生成数据图形比较.从图中可见,ROS算法生成样本和原始数据中少数类重合,容易导致过拟合.SMOTE算法生成样本相比于原始样本差异小,而且有少部分生成样本落入多数类区域成为噪声,对分类器训练不利.Borderline-SMOTE、MWMOTE及ADASYN算法侧重边界区域少数类样本,这些样本容易受落入多数类区域的噪声影响,对应忽略的噪声较大的权重,导致生成更多噪声,而且容易导致边界混合.从图3(f)中可以看出,生成样本的分布区域基本都在原始少数类样本分布区域内,而且和原始样本的关联更小,说明本文所提出的LGOS算法能够准确估计出真实样本概率密度函数,生成样本时是在真实的概率密度函数上采样,不同于基于插值的方式,生成样本时利用了全局的概率分布,生成样本相比于原始样本差异更大,提供的信息更多.
在6个UCI公开数据集上进行对比实验,用过采样算法生成的样本和原始样本混合训练决策树分类器,在测试集上的Recall、F-measure、G-mean和AUC指标见表2,粗体表示最优值.
表2 基于数据过采样的决策树分类器指标Tab.2 Metrics of decision tree classifier based on data oversampling
由表2可知,过采样扩充数据集后的分类指标相比于原始数据集基本都有所上升,因为过采样降低了数据集的不平衡比.相比于ROS、SMOTE、Borderline-SMOTE、MWMOTE及ADASYN算法,LGOS算法在Recall指标上明显优于其他方法,说明LGOS算法生成的样本较原始原本差异更大,提供的新信息更多,提高了召回率,这点从图3也可看出.但LGOS算法生成样本差异大也会引入少量噪声,降低Precision,所以在F-measure指标上没有显现出优势.在综合评价指标G-mean和AUC上,LGOS有明显优势,说明过采样同时兼顾了多数类和少数类.
把过采样少数类样本当作源领域样本,原始样本当作目标领域样本,利用本文的TrWSBoost算法训练集成分类器,本文实验中选取决策树作为基分类器,分类器测试指标见表3,粗体表示最优值.其中ROS表示先用ROS过采样算法生成少数类样本,再用所有数据训练TrWSBoost分类器, 其余类同.TrAdaboost列表示用LGOS生成少数类样本,再用TrAdaboost算法训练集成分类器.
表3 基于数据过采样的迁移学习分类器指标Tab.3 Metrics of transfer learning classifier based on data oversampling
从表2、3中可知,集成后各指标相比于单分类器均有明显提升,表示集成方法是解决不平衡学习的一个好办法.LGOS算法生成样本在集成后在各指标上均超出了其他方法.TrWSBoost算法相比于TrAdaboost算法解决了权重衰减过快的问题.在本文研究中,由于源领域样本和目标领域样本极大的相关性,防止权重衰减过快具有合理性.
4 结语
现有的不平衡分类问题过采样方法均是基于样本间插值的方法,区别在于如何区分需要关注的少数类样本以及每个样本对应的生成样本数量.然而,这些方法均没有有效利用数据的概率密度分布函数,导致生成样本相比于原始样本差异小.基于这一观察以及最近几年深度网络生成模型显现出的优越性,本文提出了一种基于隐变量后验分布生成对抗网络的过采样方法,这一方法在隐空间中采样通过生成器得到生成样本,生成模型能够学习真实样本概率分布函数,故模型能够生成和原始少数类同分布的样本.在6个公开数据集上的对比实验结果及生成数据图形可视化分布均证明了LGOS算法的优越性.另外,提出了改进的基于实例的迁移学习方法,进一步提升了分类器的性能.接下来的工作可以从几方面展开:① 本文仅关注于二类分类问题,可以扩展到多类分类问题;② 改进深度网络处理离散变量的能力以适用于带名义变量的分类问题;③ 该方法在回归问题中的应用.