APP下载

改进关系网络的小样本图像分类方法

2023-01-16王光博

沈阳理工大学学报 2023年1期
关键词:准确度卷积分类

王光博,陈 亮

(沈阳理工大学 自动化与电气工程学院,沈阳 110159)

随着科学技术的持续发展,深度学习[1]作为机器学习的分支已经在计算机视觉领域中取得了巨大的进步,且基于神经网络的图像分类技术取得了快速发展[2]。由于深度学习模型的参数多,训练一般需要依靠大量的监督数据,且有些真实样本因涉及隐私、安全问题无法收集,例如在金融、军事[3]和医学领域,由于无法获得足够多的训练数据,模型训练往往达不到预期效果。因此,需要研究小样本图像分类任务[4],让模型通过每类少数标签样本获得分类的能力。

小样本学习降低了获取数据集的难度。在小样本学习领域常用的四种方法有数据增强、迁移学习[5]、正则化和元学习[6]。利用数据增强、迁移学习或正则化的方法并不能从根本上解决小样本中过拟合的现象。元学习的主要目的是让机器学会学习,利用任务之间的共性,使模型从少量标签样本中进行算法学习,确保元学习器能够快速习得解决新学习任务的能力。Santoro等[7]提出采用长短时记忆(LSTM)网络解决单样本学习问题,Shyam等[8]提出采用循环神经网络进行样本间的动态比较。度量学习[9]可以看作元学习的一种形式,基于度量的元学习采用不同的度量方法对样本进行相似性度量,其中比较有代表性的研究成果有Sung等[10]提出的具有高效度量方式的关系网络等。张碧陶等[11]提出了融合强化学习和关系网络的小样本算法,使改进的网络结构获得了良好的分类效果。魏胜楠等[12]在关系网络中引入自注意力机制,使模型能够提取每个类别特定的信息,提高了分类准确率。王年等[13]在关系网络中加入了感受野块,增强了网络的度量能力。

本文在关系网络的基础上提出一种新的小样本图像分类方法。该方法改进原关系网络的结构,将嵌入模型的第3个卷积块替换成inception块[14],增强嵌入模块的特征提取能力;替换原关系网络中用于相关性计算的激活函数,实现更好的信息流动,有利于模型的训练;为便于模型训练,改进原关系网络的损失函数,使模型具有更好的泛化能力,有效提高小样本图像分类的准确度。

1 算法及其改进

1.1 问题描述

元学习通过学习已知的类别,完成新测试任务的分类。数据集分成训练集、验证集以及测试集,且属于不同的标签域。在训练阶段,将数据集分成多个元任务,随机抽取C种类别,每种类别中有K个标记样本,由C×K个样本构成支持集,C种类别剩余的样本构成查询集,目的是使模型在C×K个样本里学会区分C种类别,称为C-wayK-shot问题。如果K等于1,称为单样本学习,如果K大于1,则为小样本学习。本文将选择5-way 1-shot和5-way 5-shot两种方式在三个小样本检测任务常用的数据集Omniglot、MiniImagenet以及TieredImageNet上完成模型的训练和测试。

1.2 关系网络

关系网络包括嵌入模块和关系模块,其结构如图1所示。

图1 关系网络结构框图

嵌入模块的功能是对输入图片的特征信息进行提取,该模块由4个卷积层和2个最大池化层构成;关系模块的主要作用是计算支持样本与查询样本的相似度,由两个卷积层和两个全连接层构成,最后利用Sigmoid函数计算查询样本和支持样本的相似程度,从而判断两幅图像是否属于同一类别,计算表达式为

式中:ri,j为支持集样本与查询集样本的关系得分;m表示查询集样本的数量;φ、φ分别为嵌入参数和关系参数;fφ是嵌入模块的映射函数;gφ是关系模块的映射函数;C(·)为提取到的查询样本和支持样本的特征;yi和yj分别表示查询样本和支持样本的标签;I(·)是指数函数,如果支持集样本和查询集样本是不同类别,值为0,如果支持集样本和查询集样本是相同类别,值为1;argmin表示通过最小化均方误差损失优化嵌入模型和关系模型的参数。

1.3 改进关系网络

本文基于改进关系网络的小样本学习模型结构框图如图2所示,改进了原关系网络的嵌入模块和关系模块。

图2 改进关系网络结构框图

本文将原关系网络嵌入模块的第3个卷积块替换成inception块,增强嵌入模块的特征提取能力。关系模块采用3个卷积块和2个全连接层,通过缩放函数输出关系得分。在原结构的基础上增加一个卷积块的目的是为了能够将特征信息进一步卷积,并将第3个卷积层采用全局平均池化处理,避免过拟合现象的发生,通过Mish全连接层激活,最后一个全连接层使用Sigmoid和缩放函数计算查询样本和支持样本的相似程度。因为缩放函数能够使输出的特征向量维持在一个特定的范围之内,进而降低特征向量的影响程度。采用缩放函数加快了梯度的收敛速度,更换原关系网络中用于相关性计算的激活函数和损失函数,以有利于网络模型训练。通过对查询集图像和支持集图像的特征向量进行相关性计算得到相似度分数,分数最高的类即为预测的分类。

1.3.1 inception块

inception块的设计思想是用并联的方法使不同的卷积层进行组合,经过卷积层处理的结果矩阵进行拼接,形成一个更深的矩阵。inception块是让网络决定需要什么样卷积层以及是否需要池化操作,能够对一些大尺寸的矩阵进行降维操作,进而降低计算量,能够在不同尺寸上聚合图像信息,从不同尺度上提取特征。本文将原关系网络嵌入模块中的第三个卷积块替换成inception块,能够有效地改进网络的深度和宽度,提升模型的准确率,避免过拟合现象。

本文采用的inception块共有三部分,结构如图3所示。为降低通道的个数,第一部分采用1×1的卷积,并进一步减少计算量,再通过2个3×3的卷积将感受野进行放大,在减少计算量的同时不会降低网络的性能。为获得更多的表层信息,第二部分采用了常规3×3大小的卷积,同时可以保留更多的纹理信息,第三部分采用3×3最大池化层提取不同尺度的特征。最后将这三部分得到的不同特征图拼接在一起,得到多尺度特征。结合三个部分得到的不同特征映射,能够增强嵌入模块的特征表达能力。

图3 inception块结构图

1.3.2 激活函数

关系网络中的激活函数采用ReLU函数。ReLU函数为分段性函数,函数正值的收敛比较快,负值的梯度为0,会出现对应的参数不更新的情况。Mish激活函数的形式为

两种激活函数的比较如图4所示。

由图4可见,Mish激活函数是连续的光滑函数,避免了ReLU激活函数的奇异点,有更好的泛化能力和优化能力;Mish函数曲线是上无边界和有下边界,上无边界防止了梯度饱和的情况发生,有下边界与ReLU激活函数的硬零边界不同,会保留较小的负值,能够稳定网络梯度流,从而实现更好的信息流动。因此,本文采用Mish取代Re-LU作为激活函数。

图4 激活函数对比图

1.3.3 损失函数

关系网络采用的损失函数是均方误差MSE。MSE损失函数的公式为

式中:yi是输入样本的真实值;f(xi)是预测值。如果输入样本真实值和预测值之差比1小,误差会变得更小,如果输入样本真实值和预测值之差比1大,误差会变得更大,因此MSE损失函数的缺点是对离群点敏感。

本文改进网络的损失函数采用修正后的平均绝对误差SmoothL1,其公式为

式中f'(xi)是预测值。

图5为两个损失函数的对比图。

图5 损失函数对比图

从图5可以看出,均方误差损失函数的特点是连续光滑,同时函数上的每一个点都可导,便于网络模型更好地收敛;本文网络模型采用的损失函数与均方误差损失函数比较,存在对离群点不敏感的优势,无论差值多大,其惩罚都不变,有着稳定的梯度,不会出现梯度爆炸的现象,便于模型更好地训练。

1.4 模型训练的基本原理

数据集划分为训练集Dtrain、支持集Dsupport以及测试集Dtest,Dtrain对网络进行元训练,Dtest对网络的泛化性能进行测试。在元训练期间,从训练集中随机选取一些样本作为支持集,将余下的样本组成查询集,并且Dsupport和Dtest是一样的标签空间,Dtrain和Dtest无交集。

本文的网络整体模型如图6所示。首先通过损失函数训练嵌入模块,然后查询集样本和支持集样本分别进入嵌入模块进行特征提取得到特征向量,确定嵌入网络参数。将提取到的特征向量进行组合后输入关系模块进行计算,使用损失函数进行训练,得到相似度分数。当查询集样本和支持集样本是不同类别时,相似度分数接近0,当查询集样本和支持集样本是同类别时,相似度分数接近1。网络参数的训练过程为

图6 网络整体模型

网络模型训练的过程就是模拟小样本分类的场景。对比标签yi和网络模型的相似度分数ri,j,通过累加求和过程获得最终的损失值。

2 实验结果与分析

2.1 数据集

本文改进的关系网络模型在Omniglot、Mini-Imagenet和TieredImageNet三个小样本任务常用的数据集上完成实验。Omniglot数据集[15]由不同人绘制的字符组成,总计1 623个类别的字符;MiniImageNet数据集包含100种类别的60 000张图片,由600张图像构成一个类;TieredImageNet数据集是数量较大的小样本学习数据集,一共包含608个类,总计有779 165张图像,比MiniImageNet数据集中的类别有更大的语义差距,从而提供了更严格的泛化测试。

2.2 实验设置

实验条件:Intel(R)Core(TM)i5-10300H,2.50 GHz、16 GB内存,NVIDIA GeForceRTX2060显卡,Windows10操作系统,基于Pytorch深度学习框架。

各数据集的设置说明如表1所示。

表1 各数据集的设置说明

将Omniglot数据集划分成三个部分,其中训练集由1 200类图像组成,验证集由123类图像组成,测试集由300类图像组成,实验结果由测试集中随机生成的1 000个批次的分类精确度平均值表示;在MiniImageNet数据集中,训练集由64类图像组成,验证集由16类图像组成,测试集由20类图像组成,实验结果由测试集中随机生成的1 000个批次的分类精确度平均值表示;TieredImageNet数据集中的训练集由351类图像组成,验证集由97类图像组成,测试集由160类图像组成,最终实验结果由测试集中随机生成的600个批次的分类精确度平均值表示。为使分类的效果更好,将数据集图像采用翻转方式对其进行扩充,以达到数据增强的效果。

2.3 实验结果

将改进后的关系网络在Omniglot数据集上的运行结果与原关系网络进行比较,改进前后模型的准确度如表2所示。

表2 改进前后模型在Omniglot数据集上的准确度

本文的网络模型在Omniglot数据集的5-way 5-shot上的准确度为99.8%±0.32%,对比原关系网络提升的效果并不明显,但是在5-way 1-shot上的准确度为99.7%±0.32%,比原关系网络大约提高了0.1%。

将改进后的关系网络在MiniImageNet数据集上的运行结果与原关系网络进行比较,如表3所示。

表3 改进前后模型在MiniImageNet数据集上的准确度

本文的网络模型在5-way 1-shot上的准确度为54.24%±0.79%,比原关系网络模型的准确度提高了3.8%,在5-way 5-shot上的准确度为69.05%±0.71%,比关系网络提高了3.73%。

本文改进的网络模型和原关系网络在Mini-ImageNet数据集中的5-way 1-shot和5-way 5-shot任务上的迭代次数和分类准确度如图7所示。

图7 两种情况下的分类准确度

由图9可见,在20 000次之前的训练中,图像分类的准确度在不断增加,在40 000次之后的训练中,分类准确度的变化不大且基本保持平稳。本文的改进方法在在5-way 1-shot和5-way 5-shot两种情况下都表现出更好的性能。

将本文改进的网络在TieredImageNet数据集上的运行结果与原关系网络比较,如表4所示。

本文网络模型在5-way 1-shot任务上的准确度为58.69%,比关系网络提高了4.21%,在5-way 5-shot上的准确度为75.36%,比关系网络提高了4.05%,两种情况下都表现出优异的性能。

2.4 实验分析

2.4.1 网络结构分析

为进一步验证引入inception块对模型图像分类准确度的影响,在MiniImageNet数据集上分别在模型没有引入inception块和引入inception块的两种情况下进行实验,实验结果如表5所示。

表5 网络结构分析

由表5可知,在5-way 1-shot情况下,引入inception块的分类准确度提高1.17%,在5-way 5-shot情况下提高1.03%。证实了引入inception块的网络能够增强模型的分类准确度。

2.4.2 鲁棒性分析

为进一步验证本文改进网络具有的鲁棒性,在保证验证集和测试集不变的情况下,在MinImagenet数据集上的5-way 1-shot和5-way 5-shot实验中,对模型分类准确度随测试集类别数变化的情况进行比较。在MinImagenet数据集中以10类为间隔,从100类到10类依次改变测试集的类别数,两种情况下的分类准确度如图8所示。从图8可以看出,在MinImagenet数据集上,随着测试集中类别数量的减少,模型的分类准确度逐渐降低,但模型依然可以保持75%以上的分类准确度,表明本文改进模型的鲁棒性明显优于原关系网络模型。

图8 两种情况下的分类准确度

2.4.3 评估实验

为验证数据集多样性对于本文改进网络的影响,在Omniglot和MinImagenet数据集上进行实验,结果如图9所示。图9给出了模型分类准确率随数据集多样性的变化,经过对比分析可以发现:数据集多样性越高,分类准确率越低;在同样类别的情况下,样本数越多,准确率越高;在样本数相同的情况下,类别数越多,准确率越低。

图9 C-way K-shot任务

3 结论

本文在原关系网络的基础结构上,引入了inception块增加网络的宽度,更换了原关系网络模型中的损失函数和激活函数,有利于网络训练,同时保留了网络结构的简单性以及快速的训练和测试过程。在Omniglot数据集、MiniImageNet数据集和TieredImageNet数据集上的实验结果表明,改进的关系网络比原关系网络的准确率有所提高,可以提升小样本学习的泛化能力。

猜你喜欢

准确度卷积分类
基于3D-Winograd的快速卷积算法设计及FPGA实现
分类算一算
卷积神经网络的分析与设计
从滤波器理解卷积
分类讨论求坐标
幕墙用挂件安装准确度控制技术
数据分析中的分类讨论
基于傅里叶域卷积表示的目标跟踪算法
教你一招:数的分类
动态汽车衡准确度等级的现实意义