APP下载

基于生成对抗网络的花卉识别方法

2022-12-13崔艳荣卞珍怡高英宁

江苏农业科学 2022年22期
关键词:花卉残差注意力

崔艳荣, 卞珍怡, 高英宁

(长江大学计算机科学学院,湖北荆州 434023)

由于花卉种类繁多、结构复杂,花卉识别在计算机视觉和图像处理领域仍然是一个挑战。传统的花卉特征提取方法有GrabCut切割算法[1]、快速鲁棒特征(SURF)、局部二进制模式(LBP)[2]和灰度共生矩阵(GLCM)[3]等方法,存在费时费力、主观性强、模型泛化能力差且无法处理海量数据等问题。

近几年,深度学习在计算机视觉、自然语言处理等方面获得了重大突破,已广泛应用于图像识别领域[4]。林君宇等将多输入卷积神经网络和迁移学习应用到花卉识别领域,取得了95.3%的识别率[5]。吴丽娜等在LeNet-5网络模型基础上调整连接方式和池化操作,并使用随机梯度下降算法进行花卉识别,取得了96.5%的花卉识别率[6]。刘嘉政对Inception_v3模型进行深度迁移学习,对其结构进行微调,在自定义数据集上取得了93.73%的准确率[7]。关胤采用152层残差网络结构进行花卉识别,并结合迁移学习训练,取得了较好的识别效果[8]。Cao等采用基于残差网络和注意力网络的加权视觉注意力学习块进行花卉识别,在flowers17上取得了85.7%的识别率[9]。裴晓芳等将resnet18网络模型的全连接层替换为卷积层,融入了混合域注意力机制,采用Softmax进行花卉识别[10]。

现有深度学习模型都需要大量的数据进行训练,计算机视觉领域中用于数据增强并减少过拟合的传统采样方法包括旋转、裁剪、翻转、颜色转换等[11]。在很多情况下,这些方法生成的图像仅为原始数据的简单冗余副本。生成对抗网络(generative adversarial network,GAN)很好地解决了该问题,其模型中的生成器(G)和判别器(D)交替训练完成后,G可以生成大量高质量的模拟样本以进行数据增强,GAN广泛适用于图像超分辨率重建、人脸图像生成与复原、图像转换、视频预测等领域[12]。

目前,很少有人将GAN应用到花卉识别领域。本研究提出一种基于改进Wasserstein生成对抗网络[13](attention residual WGAN-GP,ARWGAN-GP)的花卉识别方法。使用残差网络构建G和D,解决了网络过深时出现的梯度消失问题,减小了模型计算量,G和D分别融入了注意力机制[14],快速有效地提取了花卉显著区域特性,且通过融合损失函数进一步优化GAN模型,生成高质量花卉样本,将判别器应用到花卉识别网络,使得花卉识别准确度显著提高。

1 改进Wasserstein生成对抗网络方法

1.1 生成对抗网络

GAN是一种生成模型,包括G和D,GAN旨在训练G合成模拟样本G(z)以混淆D,D试图区分生成样本和真实样本。G和D之间的最小-最大博弈目标函数如式(1):

(1)

其中:x采样于真实数据分布Pr(x);z采样于随机噪声分布Pg(z);D(G(z))为D判别输入为G生成数据的概率;D(x)为D判别输入为原始数据的概率。

使用交叉熵散度来度量不同样本间的距离,会导致GAN产生梯度消失问题。王怡斐等提出使用Wasserstein距离比较样本之间的差异性,改善了GAN梯度消失的缺陷,使得网络训练更加稳定[15]。式(2)为WGAN的优化目标函数。

L=-Ex~Pdata(x)[D(x)]+Ez~Pz(z)[D(G(z))]。

(2)

但WGAN对权重参数裁剪过于简单,又会导致梯度爆炸,生成的样本质量仍然不理想。Liu等提出了新的改进方法,采用梯度惩罚的方法进行权重优化,以达到加快网络训练且生成高质量样本的目的[16]。模型损失函数如式(3)所示:

(3)

LG=-Ez~Pz(z)[D(G(z))];

(4)

(5)

式(3)前2项为WGAN的优化目标函数;x为原始数据分布Pdata(x)的输入样本;z是采样于Pz(z)中的随机噪声;最后一项为梯度惩罚项;λ为梯度惩罚项参数;ε采样于标准均匀分布。

1.2 注意力机制

由于卷积神经网络只关注图像数据中的局部依赖性,在计算长距离特征时效率极低,传统的生成对抗网络可以捕获到图像中的纹理特性,但很难学习到图像中特定的结构和几何特征。在生成对抗网络中添加注意力机制,可以计算图像像素之间的相关性,并建立长距离依赖性,进一步提取到花卉样本的全局特征,生成的图像可以显示更多的细节。注意力机制原理如图1所示。

图1中X表示卷积后的特征图,将x输入到3个1×1卷积层来获得特征空间f(x)、g(x)、h(x),将f(x)和g(x)执行相应计算得到βji,如式(6)~(9):

Sij=f(xi)T⊗g(xj);

(6)

f(xi)=Wfxi;

(7)

g(xj)=Wgxj;

(8)

(9)

式中:f(x)为像素提取;Wf为f(x)的权重;g(x)为全局特征提取;Wg为g(x)的权重;⊗表示矩阵乘法;N为特征图数;βji表示注意力图;注意力机制输出层见式(10)(11):

(10)

h(xi)=Whxi。

(11)

式中:Wh是h(x)的权重。为使网络学习提取到特征图的局部和全局特征,将自注意力层Oi输出乘以系数λ并将其添加到特征图,获得注意力机制的最终输出yi。其中λ是一个可学习参数,初始值设为0。

yi=λOi+xi。

(12)

1.3 生成对抗网络模型

1.3.1 生成器 原始生成器结构为简单卷积神经网络,模型训练速度较快,但模型生成样本质量不好,会出现棋盘效应;且随着网络深度的增加,会出现梯度消失,使得网络无法训练。本研究使用残差网络来构建生成器,采用最近邻插值代替反卷积进行上采样操作,将上采样和残差网络融合在一起来解决该问题。上采样残差块如图2所示,输入样本经过批量归一化以加快模型训练速度,采用最近邻插值进行上采样,通过2层卷积提取特征;且在输入样本的同时经过最近邻插值法进行上采样,通过1层卷积提取特征,将2个特征图输出进行融合,得到上采样残差块的最终输出。

花卉图像背景复杂,存在大量噪声干扰,使得生成器生成的花卉样本效果较差。在生成器浅层网络中加入注意力机制,可以关注生成花卉样本的边缘区域特征,在深层网络中添加注意力机制,进一步合成花卉样本的纹理细节特征。本研究在生成器中加入注意力机制来提取有效花卉样本区域特征,进一步合成高质量的花卉样本。注意力机制结构如图3所示。

生成器输入采样于随机分布的128维噪声,通过全连接层转换为16 384维向量,经过维度转换大小变为(4,4,1 024)。通过5个上采样残差块进行上采样,将特征图大小依次扩大2倍,除第1层上采样残差块通道数不变,其他依次缩小为1/2,特征图大小变为(128,128,64)。在每个上采样残差块后依次添加1个注意力模块进一步提取样本特征,提升模拟样本的清晰度,注意力机制不更改样本大小。最后通过1层卷积层,得到一个维度为(128,128,3)的模拟样本。卷积层激活函数为ReLU,输出层激活函数为Tanh。图4为G结构图。

1.3.2 判别器 判别器模型结构和生成器模型结构对应,采用下采样残差块进行特征提取,融入注意力机制进一步提取花卉区域样本特征,将维度为(128,128,3)的真实样本和模拟样本传入判别器,通过5层下采样残差块进行特征提取,使得特征图数不断增加,图片大小不断减小。在每层下采样残差块后依次添加1层注意力模块进行特征提取,约束模拟样本的细节特征,提高模拟样本的真实性,且注意力机制不改变特征图大小。最后通过卷积层得到(4,4,1 024)的特征图,通过全连接层进行判断。D中卷积层均为Leaky ReLU激活函数。图5为D结构图,图6为下采样残差块结构图。

1.3.3 损失函数及模型训练 为使得G可以生成清晰度更高的,且具有多样性的高质量花卉样本,生成器采用融合损失函数,将对抗损失、注意力损失和重构损失进行加权融合。判别器损失函数采用式(3)计算。

1.3.3.1 对抗损失 对抗损失为wgan-gp的生成器损失函数。如式(4)所示,改善了GAN和WGAN训练时出现的梯度消失,训练解决不稳定和生成花卉样本效果不佳的缺陷。

1.3.3.2 注意力损失 为更好地提取花卉样本的局部和全局性特征,生成纹理清晰、视觉上和真实样本高度相似且具有多样性的模拟样本,引入注意力损失,如式(13)所示。

(13)

式中:yi表示注意力机制输出,同式(12);θi表示注意力机制输出层的权重,浅层的注意力层输出可用信息较少,权重较小,深层输出权重较大,经对比试验验证,权重参数依次选为1,1,1,2,2,G(z)为生成模拟样本。

1.3.3.3 重构损失 重构损失为生成样本与真实花卉样本之间的L1距离,可以较好地反映生成花卉样本的真实性,如式(14)所示。

Lrec=Ex~Pdata(x),z~Pz(z)[‖G(z)-x‖1]。

(14)

式中:x为原始数据分布Pdata(x)的输入样本;z是采样于Pz(z)中的随机噪声。

融合目标损失函数为式(15)所示。

Llos=δ1LG+δ2Latt+δ3Lrec。

(15)

式中:δ1,δ2,δ3为损失函数的权重。经对比试验分析得到,δ1为1,δ2为0.05,δ3为10时效果最好。

G的训练需要固定D参数,随机噪声经过生成器进行一系列的上采样后生成模拟样本,将其送入到D进行判别,尽最大可能使D判别生成的样本为真实样本。D需要送入生成样本和真实样本进行参数优化,根据式(15)和式(3)计算生成器融合损失值和D的损失值,采用Adam算法进行参数调整,融合损失函数值主要为引导生成器生成更高质量的样本,D损失函数值可以表现网络模型的训练情况,当该值趋于稳定收敛时,表明网络模型训练近似达到最优,此时生成器加权损失函数也趋于稳定,生成的模拟样本质量更高。交替对抗训练G和D,为防止过拟合,加快模型收敛,G和D训练次数设为1 ∶k。

1.4 花卉识别模型

ARWGAN-GP训练完成后,G可以生成纹理清晰,视觉上和真实样本高度相似且具有多样性的模拟样本,判别器可以快速提取花卉样本特征。将训练好的生成对抗网络模型进行调整,以解决花卉识别准确度低的问题。图7为花卉识别网络模型。本研究迁移判别器网络参数到花卉识别网络,大幅度减小了花卉识别网络训练时间,且进一步提高了花卉识别率,替换全连接层为新设计的全连接分类层,使用softmax激活函数进行花卉识别。对花卉识别模型进行适当的参数调整以适应新任务的要求,使用交叉熵损失函数和Adam算法调整网络参数,采用生成器生成的模拟样本作为训练集训练花卉识别网络。

2 结果与分析

2.1 试验环境与数据集

本研究试验平台为Windows10,GPU为NVIDIA GEFORCE GTX 1080,深度学习架构为keras和Tensorflow。选择Oxford 102花卉数据集作为数据样本,包含102种花卉,共8 189张图片,将花卉样本等比例缩放为128×128像素,示例如图8所示。训练集和测试集的比例设置为9 ∶1。

2.2 试验设计

2.2.1 ARWGAN-GP模型训练及验证 本研究使用oxford102花卉数据集训练ARWGAN-GP,迭代次数为20 000,批处理样本数为32,G和D学习率分别为0.000 1和0.000 4,G和D优化更新次数为1 ∶3。使用G为每张花卉数据对应生成大量模拟样本作为训练集,训练本研究的花卉识别网络。

图9为ARWGAN-GP在不同迭代次数时判别器损失函数值。在模型开始训练阶段,D损失函数值震荡幅度较大。此时,G生成样本能力较弱,融合损失函数值和D损失函数值不断引导G生成更高质量的样本,经过多次迭代后,D损失函数值震荡范围缩小,下降到较小值且趋于收敛,表明此阶段为模型学习阶段。随着试验的进行,模型不断学习优化,当训练次数达到10 000次时,D损失函数值趋于稳定收敛,表明ARWGAN-GP得到了充分的训练,模型已经达到最优。此时,G可以生成高质量的模拟样本。训练完成后,使用G生成大量模拟花卉样本。

为验证本研究生成的对抗网络结构和融合损失函数的有效性,设置以下对比试验进行验证。试验1、2、3均采用WGAN-GP模型,试验1网络结构以本研究生成器结构为基础,去掉注意力机制,并采用反卷积神经网络代替上采样残差块结构。试验2网络结构以本研究生成器结构为基础,并去掉注意力机制,试验4为本研究生成对抗网络模型,试验3和试验4均使用本研究生成器结构。判别器结构均与生成器相对应。生成花卉样本如图10所示。

图10表明模型训练完成后,生成器可以生成纹理清晰、视觉上和真实样本高度相似且具有多样性的模拟样本。

本研究采用PSNR(峰值信噪比)、SSIM(结构相似性)和损失函数来对生成样本质量进行评价,PSNR值越大表明生成样本的质量越好,SSIM值越大表明生成样本的视觉效果越好。表1为PSNR和SSIM评估值。

表1 生成样本质量评估

图11为4组试验的损失函数图。

由图10、图11和表1可看出,试验1在迭代到 12 500 次时,模型损失函数趋于稳定收敛,生成的花卉样本存在部分模糊情况,这是由于生成对抗网络训练并没有充分学习到花卉样本特征,PSNR值为24.48 dB,SSIM为0.788 2。试验2相较于试验1模型收敛速度加快,表明使用上采样残差块加快了模型训练速度,且提高了模型特征提取能力,使得生成对抗网络生成样本能力得到进一步提升,PSNR值为25.74 dB,SSIM值为0.816 4,生成的花卉样本目标边缘更加清晰,视觉效果较好,质量更高。试验3在试验2基础上又加入了注意力机制,进一步关注有效花卉区域样本特征,使得生成的花卉样本纹理理细节更加清晰,PSNR为26.89 dB,SSIM为0.834 7。试验4使用改进的融合损失函数,使得网络进一步关注有效花卉区域,网络模型训练更加稳定,得到更高的PSNR和SSIM,生成花卉样本纹理更清晰,视觉效果更好,质量更高,进一步说明本研究生成对抗网络结构和融合损失函数的有效性。

2.2.2 花卉识别网络训练及生成样本评估 花卉识别网络使用Adam优化器调整模型参数,迭代次数为5 000,学习率为0.001,批处理样本数为64,使用原始训练集训练花卉识别网络。花卉识别网络识别准确度如图12所示。当网络迭代到3 000次时,花卉识别率趋于稳定,达到92.49%,网络达到最优状态。

为测试生成器生成样本的数量对花卉识别率的影响,设计了6组对比试验,使用训练完成的生成器为每张花卉数据对应生成50、60、70、80、90、100张模拟样本作为训练集训练本研究的花卉识别网络。试验结果如图13所示。

由图13可以看出,使用生成样本作为训练集使得准确率得到了很大提升,表明ARWGAN-GP模型生成的样本纹理清晰、视觉上和真实样本高度相似且具有多样性模拟样本的有效性。随着生成模拟样本数量的增多,对花卉数据集的增强效果逐渐趋于稳定,当花卉样本数达到80张时,花卉识别率逐渐趋于稳定,达到98.36%,此时模型已经处于收敛状态。

为验证本研究生成花卉样本进行数据增强和花卉识别网络的有效性,分别设置3组花卉识别网络和6组数据集进行试验验证。采用传统方法对原始数据集进行随机裁剪、旋转、缩放、偏移,等比例放大80倍,数据集设为D1,使用 “2.2.1”节4组试验训练完成后生成的样本数据,分别对应生成80张花卉样本,分别设为数据集D2、D3、D4、D5,花卉识别网络分别采用“2.2.1”节的试验1、试验2、试验4的判别器结构,并对最后的全连接层进行修改,花卉识别网络分别设为Conv、DownRes、VaDownRes。试验结果如表2所示。

表2 不同条件下花卉识别率

由表2可知,在不采用数据增强时,在3个分类网络上花卉识别平均准确率为91.21%,在D1数据集进行训练得到了92.75%的平均花卉识别率,而在D2数据集上进行训练则取得了95.14%的平均花卉识别率,相较于前2组数据集有较大提高。这是由于CNN对于旋转、缩放、偏移、裁剪等存在相应的不变性,在采用裁剪、旋转、缩放、偏移进行数据增强时,部分生成的样本数据和真实样本特性相同,仅仅是对真实数据的简单复制,生成的模拟样本数据多样性不足,使得网络识别效果不理想。而生成对抗网络进行训练时,生成器和判别器通过交替训练不断学习花卉样本特性,不断拟合花卉数据,当模型训练完成后,生成器可以生成纹理清晰、视觉上和真实样本高度相似且具有多样性的模拟样本,大幅度提高了花卉识别准确度。对比试验分析得到,在D5数据集上训练得到的花卉识别率要高于在D2、D3、D4数据集上训练得到的结果,表明本研究生成的对抗网络结构和融合损失函数具有有效性,进一步说明采用生成对抗网络生成模拟花卉样本可有效进行数据增强。

由表2可以看出,在6个花卉数据集上,DownRes模型的平均花卉识别率为94.70%,高于在Conv模型上的平均花卉识别率93.15%,表明使用下采样残差块构建花卉识别网络相较卷积神经网络大幅度提高了花卉特征提取能力,进一步说明花卉识别网络采用下采样残差块提取花卉样本特征更高效。在花卉识别网络融入注意力机制后,VaDownRes模型的平均花卉识别率得到了较大提高,进一步说明融入注意力机制后,使得花卉显著区域特征提取能力得到提高,大幅度提高了花卉的识别准确率。

2.2.3 花卉识别方法对比试验 设置以下试验验证本研究方法的有效性。

试验1:文献[17]提出使用CNN来进行花卉识别,与传统的花卉识别方法不同,该方法使用CNN自动学习花卉样本特性。

试验2:文献[18]提出在CNN添加注意力机制进行花卉识别,使用CNN自动提取样本特征,通过注意力机制进一步提取深度特征。

试验3:采用文献[19]提出的方法,利用预训练模型resnet50在花卉图像上进行迁移微调,重新构建新的分类层,在本研究原始数据集上进行重新训练。

试验4:采用文献[9]提出的方法,以resnet50为基础框架构建基于注意力机制驱动的残差网络,并通过全局平均池化和全连接层实现花卉分类,在本研究原始数据集上重训练。

试验5:使用花卉数据集训练ARWGAN-GP,训练结束后使用生成器网络进行数据增强,且迁移D参数到花卉识别网络,对其参数微调,使用增强数据重新训练花卉识别网络模型。

不同试验下花卉识别准确度如表3所示。

表3 不同试验下花卉识别准确度

由表3可知,试验1基于CNN进行自动提取花卉特征可以达到83.00%的准确度。试验2在CNN的基础上添加注意力机制,相比单独使用CNN进行花卉识别,该方法利用注意力机制融合花卉样本的局部和全局特征,进一步学习捕获到深度花卉特征,在一定程度上提高了准确率。试验3使用深度残差网络进行花卉识别,相比使用CNN提高了花卉识别准确度,这是由于为了提高网络的识别率,需要增强网络深度,但这会导致梯度消失,而残差网络改善了该缺陷,残差网络更容易优化,收敛更快且准确度更高。试验4在深度残差神经网络的基础上加入了注意力机制,相比试验3提高了花卉识别率,加入注意力机制后,可以有效提取花卉显著区域特征,减小噪声干扰,增强了网络的学习能力,使得准确度更高。试验5采用本研究提出的花卉识别网络模型,相比前4组试验,该方法更进一步提高了花卉识别准确度,这是由于前4组试验的数据量偏小,很难达到较好的收敛效果。而本研究采用残差网络和注意力机制构建生成对抗网络,并使用融合损失函数,使得生成对抗网络充分提取到了花卉样本特征,使用训练结束的ARWGAN-GP模型进行数据增强,使得样本得到了有效扩充,且迁移D参数到花卉识别网络,加快了花卉识别网络模型的收敛速度,使用生成数据进行训练花卉识别网络,进一步提高了模型的识别率。

3 结束语

本研究提出了一种基于改进生成对抗网络的花卉识别方法。使用残差网络构建生成器和判别器,解决了网络深度加深时出现的梯度消失和训练不稳定问题,使得网络收敛更快;融入了注意力机制,可以快速有效地提取花卉显著区域特征,减小了噪声干扰,且改进了损失函数,进一步提高生成对抗网络的能力;ARWGAN-GP训练结束后,采用生成器进行数据增强,迁移判别器参数到花卉识别模型,并进行参数微调,加快了模型的收敛速度,进一步提高了模型的识别准确度。

猜你喜欢

花卉残差注意力
李鱓·花卉十二开
基于双向GRU与残差拟合的车辆跟驰建模
让注意力“飞”回来
三招搞定花卉病虫害
基于残差学习的自适应无人机目标跟踪算法
《花卉之二》
基于递归残差网络的图像超分辨率重建
水晶泥花卉栽培技术
“扬眼”APP:让注意力“变现”
平稳自相关过程的残差累积和控制图