APP下载

混合高斯变分自编码器的聚类网络

2022-07-15陈华华陈哲郭春生应娜叶学义

中国图象图形学报 2022年7期
关键词:编码器高斯聚类

陈华华,陈哲,郭春生,应娜,叶学义

杭州电子科技大学通信工程学院,杭州 310018

0 引 言

随着科学技术的发展,数据的规模呈指数型增长,这些海量数据往往蕴含着许多极具价值的潜在信息,如何捕获并挖掘这些隐藏的潜在信息是一个亟需解决的问题,具有重要的现实意义。

聚类分析是数据挖掘领域的一种研究方法,也是目前用于发现数据间潜在信息的主要方法之一。聚类分析旨在发现数据间潜在的关系,并根据数据的特征等将数据聚为不同的类别,也称为簇,使得簇内的数据具有较小的差异性,簇间的数据具有较大的差异性。聚类分析已广泛应用于用户画像(张海涛 等,2018)、协同过滤推荐系统(吴湖 等,2010)、基因分析(岳峰 等,2008)、异常检测(成宝芝 等,2017)和文本聚类(路荣 等,2012)等领域,吸引了越来越多的学者加入到聚类分析的研究队伍中。

1 相关工作

MacQueen(1967)提出的K-means聚类算法和Ester等人(1996)提出的DBSCAN(density-based spatial clustering of applications with noise)聚类算法是两个经典的聚类方法。K-means算法的基本思想是将数据分为不同的类别,每一个类别具有一个聚类中心,根据数据到各聚类中心的距离更新优化聚类划分,反复迭代得到一个最佳的聚类结果。DBSCAN是一种基于密度的聚类算法,根据邻域内的点数判断该区域是否属于密集区域并形成临时聚簇,将相连的临时聚簇进行合并得到更合理的聚簇分配。

然而,随着数据维度的提高,经典的聚类算法在应对高维数据时往往存在维数灾难等问题,使得计算成本大幅增加并且效果不佳。近年来,深度神经网络得到了飞速的发展。因此,越来越多的研究人员将目光转向利用深度学习进行聚类分析。相比于传统方法,深度学习能够更好地得到数据特征的低维表示,提高聚类效果。Yang等人(2017)提出了DCN(deep clustering network)模型,该模型训练一个自编码器AE(autoencoder),得到数据特征的低维表示,然后结合K-means聚类算法,将重构损失和K-means聚类损失进行联合优化,该模型聚类效果高于同时期的传统算法。Xie等人(2016)提出了DEC(deep embedded clustering)模型,模型训练一个堆叠的自编码器得到数据特征的低维表示,然后在此基础上构建聚类网络。Opochinsky等人(2020)和Chazan等人(2019)用多个自编码器建立深度聚类网络,每个自编码器表示一个簇,取得了更好的聚类效果。Duan等人(2019)采用自编码生成一个深度嵌入网用于数据降维,学习一个softmax自编码器用于估计簇的数目,获得了较好的实验结果。但是自编码器的目的主要是为了降维,网络训练目标是使解码器输出和输入尽可能逼近,当训练样本与预测样本不符合相同分布时,提取的特征往往比较差。与自编码器相关的另一种编码器是变分自编码器(variational autoencoder,VAE),Jiang等人(2017)和Lim等人(2020)提出了变分自编码器用于聚类,先用VAE生成隐层特征,然后用混合高斯分布拟合隐层特征,聚类效果优于经典的聚类方法和一些生成式聚类方法。变分自编码器采用标准正态分布作为先验,容易导致后验塌陷(Guo等,2020),对不同类别数据的分布不能较好地逼近,影响编码和解码结果。为此,本文引入混合高斯分布作为先验,构建混合高斯自编码器生成隐层特征,学习数据的特征分布,并以此编码器与聚类层结合形成聚类网络,通过优化编码器隐层特征的软分配分布与软分配概率辅助目标分布之间的KL散度(Kullback-Leibler divergence)对聚类网络进行训练。在基准数据集MNIST(Modified National Institute of Standards and Technology Database)和Fashion-MNIST上的实验结果表明,本文网络取得了较好的聚类效果,且优于当前多种流行的聚类方法。

2 聚类模型

2.1 混合高斯变分自编码器

变分自编码器(Kingma和Welling,2014)是深度学习领域一种近似推理的有向模型,它利用变分推断与深度学习相结合,能够学习到数据结构化的特征表示或分布,是生成式建模中的一种重要方法。标准变分自编码器的目标函数可表示为

L(θ,φ;x)=-DKL[qφ(z|x)‖pθ(z)]+
Eqφ(z|x)[logpθ(x|z)]

(1)

式中,当近似后验分布与假设先验分布之间的KL散度DKL[qφ(z|x)‖pθ(z)]最小时,目标函数L(θ,φ;x)下限最大,从而使模型达到最优。其中,E表示求期望,θ和φ分别是先验分布p和后验分布q的参数,x和z分别是变分自编码器的输入和隐层特征。但是,标准变分自编码器中采用标准正态分布作为先验,可能引起后验塌陷(Guo等,2020),并且容易忽略一些潜在的变量约束,导致对不同类别数据的分布不能较好地逼近,影响自编码网络的编码和解码结果。在这里,本文引入混合高斯分布作为先验,构建变分自编码器。混合高斯先验可表示为

(2)

(3)

(4)

(5)

式中,ωi是第i个单高斯分布的系数,μ(i)和σ(i)2代表第i个单高斯分布的均值和方差。式(1)中等号右侧的第1项是混合高斯分布近似后验与先验之间的KL散度,目前尚无高效的算法可以获得它的解析解。为此,Hershey和Olsen(2007)提出了一种近似解法,采用变分推断获得KL 散度的上界。于是最小化KL 散度转换为最小化其近似上界,该算法简述如下:

设数据x的混合高斯分布f(x)和g(x)分别可表示为

(6)

(7)

(8)

式中,fi(x)和gi(x)分别代表f(x)和g(x)中等号右侧第i个单高斯分布。

因此,式(1)中等号右侧第1项可以表示为

(9)

式(1)中等号右侧的第2项是重构项,它的计算方式与标准变分自编码器类似,具体为

(10)

式中,L是采样的数量,zl的下标l表示第l次采样。由式(9)和式(10)可得到混合高斯分布后验与先验的变分下界,即

(11)

式(11)即为混合高斯分布自编码器的目标函数。

2.2 聚类网络

本文在混合高斯变分自编码器的基础上,使用编码器部分作为数据空间和特征空间之间的初始映射,将编码器和聚类层组合成聚类网络,如图1所示。图1中,以学习获得的混合高斯分布自编码器部分作为聚类网络的编码器部分,结合聚类层通过最小化辅助目标分布和软分配分布之间的KL散度学习聚类网络。具体过程如下:

图1 聚类网络结构图Fig.1 Structure of the clustering network

假设存在数据集{x(1),x(2),…,x(N)},N为数据样本的数量,输入数据x(i)到混合高斯变分自编码器中得到隐层特征z(i),使用欧氏距离计算隐层特征z(i)到聚类中心c(t)的距离,c(t)表示第t个聚类中心,并使用t分布(van der Maaten和Hinton,2008)衡量隐层特征z(i)到聚类中心c(t)之间的相似度sit,也即特征z(i)分配到聚类类别t的软分配概率(Dempster等,1977),sit计算为

(12)

式中,χ代表t分布的自由度,由于聚类属于无监督学习,无法交叉验证χ的取值,因此本文取χ=1。将数据输入到训练获得的混合高斯变分自编码器得到隐层特征,然后在特征空间采用混合高斯模型进行聚类,得到T个聚类中心{c(1),c(2),…,c(T)}作为初始化的聚类中心。

本文采用最小化辅助目标分布和软分配之间的KL散度实现对模型的优化,因此辅助目标分布的选择对聚类效果至关重要。Xie等人(2016)提出辅助目标分布应该更加注重高置信度的数据点,以提高聚类的准确性。同时,需要归一化代价函数对每个聚类中心的贡献,防止出现过大的聚簇导致隐层的特征空间扭曲。因此,将软分配概率的辅助目标分布pit定义为

(13)

由此,得到聚类层的损失Lcluster,定义为软分配分布s与辅助目标分布p之间的KL散度,采用随机梯度下降法优化损失函数。具体为

(14)

2.3 网络结构

本文中的变分自编码器采用卷积神经网络实现,实现的网络结构如图1所示。网络的编码器部分,首先是核大小为3×3、步长为2、通道数为64的卷积层,激活函数采用ReLU(rectified linear unit)函数。考虑到池化层在实现下采样时存在丢失有用信息的不足(Sabour等,2017),本文采用核大小为3×3、步长为2的卷积层实现下采样以保留重要信息。然后级联大小为3 × 3、步长为2、通道数为128的卷积层,再级联大小为3 × 3、步长为2、通道数为256的卷积层,然后连接两个全连接层,维数分别是2 304和10。

解码器部分在结构上与编码器部分是对称的,首先是两个级联的全连接层,它们的维数分别是10和2 304,然后级联一个维数变形层(reshape层),将数据维度从2 304转换为3 × 3 × 256,然后级联3个卷积核大小为3 × 3的反卷积层,它们分别具有128通道、64通道、1通道。与编码器类似,解码器采用步长为2的反卷积层实现上采样。解码器在最后一层的反卷积层使用sigmoid函数作为激活函数,其余卷积层、反卷积层的激活函数都采用ReLU函数。

聚类网络中,自编码器的结构直接采用混合高斯分布自编码器的编码部分,并以学习获得的编码器参数作为初始值,进一步按式(14)所示目标函数优化学习聚类网络。

3 实验结果分析

3.1 实验配置

实验采用的计算机配置如下:Intel(R)Core(TM)i5-7300HQ@2.50 GHz CPU,Windows 10,编译器Python3.6,内存为8 GB,编程环境为TensorFlow和Keras,编程语言为Python。网络训练参数值采用正态分布随机初始化,batch size为100,优化方法为Adam优化器,学习率为0.000 1。

为评估本文方法的有效性,采用了聚类分析中常用的MNIST(LeCun等,1998)数据集和Fashion-MNIST(Xiao等,2017)数据集分别进行实验评测。MNIST数据集是LeCun等人(1998)在美国国家标准与技术研究院提供的手写数据集的基础上筛选,并进行了尺寸标准化及数字中心化等处理的标准数据集,由60 000个训练样本和10 000个测试样本组成,每个样本都是28 × 28像素的灰度图像;Fashion-MNIST是一个替代MNIST手写数字集的图像数据集,由德国科技公司Zalando旗下的研究部门提供,涵盖了10种类别共70 000个不同时尚商品的正面灰度图像,图像大小为28 × 28像素,包括T恤、裤子、套衫、裙子、外套、凉鞋、衬衫、运动鞋、包和靴子等10类商品。为简化问题求解,实验中使用的混合高斯中高斯分量的个数M= 10,各分量的混合系数取经验值为1/10,初始化聚类中心的个数T= 10。

3.2 评价指标

采用聚类精度(accuracy,ACC)和标准互信息(normalized mutual information,NMI)作为评估指标。

聚类精度ACC用于衡量算法得到的聚类标签准确性,计算为

(15)

标准化互信息NMI是衡量两个随机事件之间相关性的重要指标,也是常用的聚类评估指标之一,这里用做衡量聚类标签与真实数据类别标签的契合程度。标准化互信息计算为

(16)

3.3 实验结果

本文模型分别在MNIST和Fashion-MNIST数据集上进行了聚类实验。为了进一步验证本文方法的有效性,与高斯混合模型(Gaussian mixture model,GMM)(Fraley和Raftery,1998)、VAE+K-means(Kingma和Welling,2014)、DEC(deep embedded clustering)(Xie等,2016)、IDEC(improved DEC)(Guo等,2017)、GMVAE(Dilokthanakul等,2017)、KADC (K-autoencoders deep clustering)(Opochinsky等,2020)、VaDE(variational deep embedding)(Jiang等,2017)、DCVA(deep clustering with VAE)(Lim等,2020)和ClusterGAN(clustering in generative adversarial networks)(Mukherjee等,2019)等算法进行对比,对比结果如表1和表2所示。

从表1和表2可以看出,本文模型与GMM、VAE + K-means、DEC、IDEC、GMVAE、VaDE和DCVA相比,在MNIST和Fashion-MNIST数据集上的聚类精度ACC和标准互信息NMI均有较大提升。与KADC相比,本文方法在Fashion-MNIST数据集上的标准互信息NMI低于KADC。与ClusterGAN算法相比,本文算法在MNST数据集上的聚类ACC和标准互信息NMI优于ClusterGAN,但在Fashion-MNIST数据集上略逊于ClusterGAN算法。除了GMM和ClusterGAN方法,其他方法都是从AE或VAE基础上发展起来的,其中性能最好的VaDE方法对聚簇采用单个质心表示嵌入空间中的向量,而本文方法采用一个混合高斯自动编码器网络表示嵌入空间中的向量,这样对每个聚簇可以实现更丰富的表示,同时辅助目标分布的引入提高了聚类的准确性,防止了隐层的特征空间扭曲,使得算法具有良好的聚类表现;ClusterGAN和本文方法虽然都属于生成式聚类方法,但是两者无论是从网络结构还是实现思路上属于两个不同的方法类别,ClusterGAN是目前聚类性能最好的生成对抗网络,本文方法与该方法的性能相当。总体而言,本文算法具有较好的聚类结果。

表1 不同算法在MNIST和Fashion-MNIST数据集上的ACC比较Table 1 Comparison of ACC among different algorithms on MNIST and Fashion-MNIST datasets /%

表2 不同算法在MNIST和Fashion-MNIST数据集上的NMI比较Table 2 Comparison of NMI among different algorithms on MNIST and Fashion-MNIST datasets /%

此外,由表1和表2还可知,尽管MNIST和Fashion-MNIST都是28 × 28像素的高维图像数据,具有相同的数据维度,但是无论是ACC指标还是NMI指标,各方法除了GMM方法在改变数据集时指标下降比较小,其他方法都出现了大幅度下降,这是因为MNIST数据集由灰度变化范围小的图像组成,纹理特征信息比较单一,主要是字符的边界信息,在描述数据特征时对各模型的特征表述能力要求相对较低,而Fashion-MNIST数据集由灰度变化范围大的图像组成,纹理特征信息比较丰富,如衣服、鞋子和包等丰富的内部纹理和边缘,在描述数据特征时对各模型的特征表述能力要求较高,而目前的方法对Fashion-MNIST数据集的特征信息表达均不是很强,导致聚类性能明显劣于MNIST数据集。

图2列出了本文模型在MNIST和Fashion-MNIST数据集中每个类的10幅图像,其中每一行对应一个聚簇。由图2(a)可知,本文模型对MNIST数据集的聚类结果较为准确,但出现了若干“4”和“9”混淆的情况,这一结果与“4”和“9”的外观特征相似有关。由图2(b)可知,Fashion-MNIST数据集下的每个聚类依次为凉鞋、外套、靴子、套衫、裤子、运动鞋、T恤、衬衫、包和裙子。聚类结果中的凉鞋、靴子、裤子、运动鞋、T恤衫、包和裙子等类别的聚类较为准确,在外套、套衫和衬衫这3类中出现了若干次混淆的情况,这一结果的产生与这3类物体的外观较为相似有关,区分这3类物品更依赖于内部纹理特征的差异。

图2 MNIST和Fashion-MNIST数据集的聚类结果Fig.2 Clustering results on MNIST and Fashion-MNIST datasets ((a) MNIST;(b) Fashion-MNIST)

图3列出了本文模型与ClusterGAN算法在Fashion-MNIST数据集下的重建结果。由图3可知,ClusterGAN算法重建得到图像的纹理特征较本文模型清晰,其在纹理特征的提取与重建上优于本文方法。因此,ClusterGAN的聚类效果在Fashion-MNIST数据集上略优于本文算法。

图3 本文模型与ClusterGAN算法的重建结果Fig.3 Reconstruction of the proposed network and ClusterGAN((a) ours;(b) ClusterGAN)

3.4 结构复杂度分析

本文模型在Fashion-MNIST数据集上的聚类指标略低于ClusterGAN算法。本文模型和ClusterGAN算法都属于生成式聚类方法,不同的是ClusterGAN采用生成对抗网络用于聚类,而本文采用混合高斯分布的变分自编码器。对这两种算法的模型参数量进行了比较分析,如表3所示。

表3 本文模型与ClusterGAN算法的模型参数量比较Table 3 Comparison of the number of parameters in the proposed network and ClusterGAN

由表3可知,本文模型参数量不及ClusterGAN算法的1/10,远小于ClusterGAN算法,是一个更轻量级的网络模型,这使得本文模型占用更小的存储空间,降低了对内存的需求,同时能够实现更快的运行速度,在Fashion-MNIST数据集上的性能差异小于2%。

4 结 论

本文提出了一种基于混合高斯变分自编码器的聚类网络模型,以混合高斯分布为先验建立变分自编码器,学习数据的特征分布,然后将编码器与聚类层结合构建聚类网络,采用编码器隐层特征的软分配分布与软分配概率辅助目标分布之间的KL散度作为目标函数,对网络进行训练和优化。在基准数据集MNIST和Fashion-MNIST上进行了评价和比较,对比实验结果表明,采用混合高斯自动编码器网络对每个聚簇可以实现更丰富的表示,辅助目标分布的引入使得算法具有良好的聚类表现,使得本文方法在聚类ACC和标准互信息NMI指标都优于当前的一些聚类算法,取得了较好的聚类效果。

但是本文算法也存在两个不足:1)在模型建立上,先验和后验中的高斯分量个数设为相同,虽然实验结果优于当前的一些聚类算法,但是处理方法只是对实际情况的简化处理,与实际情况存在一些差别,实际更一般的情况中,先验和后验中的高斯分量个数并不相同,如何优化求解这个问题是个难题;同时,本文中各高斯分量的混合系数不是通过数学优化求得的最佳混合系数,而是根据经验设定为等概率混合的,缺乏完善的理论支持。2)在模型对信息的表达能力上,当处理复杂纹理特征时,纹理特征的重建效果有待提高。上述问题将是下一步研究的重点。

猜你喜欢

编码器高斯聚类
基于ResNet18特征编码器的水稻病虫害图像描述生成
基于数据降维与聚类的车联网数据分析应用
基于模糊聚类和支持向量回归的成绩预测
数学王子高斯
基于TMS320F28335的绝对式光电编码器驱动设计
基于密度的自适应搜索增量聚类法
从自卑到自信 瑞恩·高斯林
具备DV解码功能的DVD编码器——数字视频刻录应用的理想选择