APP下载

基于Boosting的深度学习图像分类算法设计

2021-12-18王志强王先传

系统仿真技术 2021年2期
关键词:网络结构残差准确率

王志强,王先传

(1.阜阳工业经济学校,安徽阜阳 236015;2.阜阳师范大学,安徽阜阳 236041)

目前深度学习应用较为广泛[1-2],其是对人类大脑中的神经网络进行模拟,从而实现对于输入数据(图像、文本等)的处理。因此,随着深度学习的不断发展以及所需处理的数据量的增大,可以看出在进行图像分类及目标识别的应用中,深度学习的应用至关重要,非线性网络变得更加有利,可以从输入数据中得到更多有利信息[3-4]。针对这个情况提出ResNet网络,在网络结构中引入快捷方式使输入同时引入至输出,最终学习输出与输入之间的差值,构成残差结构,使得每一次仅用学习残差的部分,成功解决退化问题,使ResNet网络的应用比较广泛。

1 ResNet网络结构及Boosting工作机制

1.1 ResNet网络结构

ResNet网络[5-6]是在之前的VGG网络结构基础上添加快捷方式,这样每次学习的便仅仅是残差的部分。即每个ResNet网络都由许多残差块堆积而成,其中每一个残差块都由一个神经网络模型和一个快捷方式构成。一个残差块的结构如图1所示。

图1 残差块的结构示意图Fig.1 Structure diagram of residual block

第t层残差块的输出为

其中,x为网络的输入,ft(·)代表第t层的卷积层,gt(x)代表t-1层的最终输出。

1.2 Boosting工作机制

Boosting的工作流程大致为:对初始提供的数据集进行训练,得到一个学习器。根据学习器的结果表现,相应地对数据进行处理,即对于做错的数据,改变它相应的权重,使得其在后续的训练中得到更多的关注。Boosting的工作机制如图2所示。

图2 Boosting的工作机制Fig.2 Working mechanism of Boosting

2 基于Boosting的图像分类算法设计与实现

2.1 设计算法的优点

在ResNet网络每一层的输出处引入一个线性的全连接层,将得到的结果通过全连接层引出来[7],此时相当于将提取好的特征拿出来,在后续的网络中便不再对已经分类好的数据进行处理。即在每一层中将分类效果很好的类别先提取出来,随后将剩下的再次进行训练测试,直到达到较好的结果,这样的好处在于使得训练的时间变短、结构简单,在后续的网络中所要训练提取的特征变少,准确率得到提升。

2.2 网络构建的设计

已知的ResNet网络共有5种结构,分别为ResNet18、ResNet34、ResNet50、ResNet101、ResNet152。其都有自己的网络构架,大体上都是统一的,只是在具体的层中有所不同。

利用Boosting的思想可以对网络结构进行改进,具体构建方法为[8]:在conv_1、conv_2、conv_3、conv_4四个层的输出处分别引入一个线性的全连接层,利用Python中的nn.linear()函数实现,得到4个全连接层结果为y1、y2、y3、y4。在引入四个全连接层时,需要利用a.reshape()函数先对结果进行处理,改变所得结果的矩阵形状,使之与全连接层的参数保持匹配,随后得到全连接层的结果,如图3所示。

图3 基于Boosting的图像分类算法设计结构图Fig.3 Design structure chart of image classification algorithm based on Boosting

2.3 网络构建的设计测试

2.3.1 Loss的训练方法及变化

网络的Loss即在定义了损失函数Criterion之后,计算网络输出与真实标签之间的误差,从而得到损失值Loss。在研究中希望Loss越小越好,故在得到Loss后,需要进行反向传播,改变相应的参数,使得网络的Loss不断减小,直到达到标准或是完成所有的训练次数。本文使用pytorch完成实验验证,则通过Loss.backward()函数完成误差的反向传播,pytorch的内在机制可以实现自动求导得到每个参数的梯度。故在实现网络结构的重新构建后,对网络进行训练时,根据Loss需要作出相应的改变。

2.3.2 预测值的变化

在设计分类网络训练及测试过程中,预测值不再采用原来的网络输出Outputs的最大值,而是应用Boosting的思想,需要采用网络的最终输出Outputs以及每一个引出的全连接层的输出y1,y2,y3,y4的和,并通过torch.max()函数得到Outputs+y1+y2+y3+y4的最大值,以此作为网络的预测值。

3 网络构建的设计实现

3.1 Cifar类数据集

Cifar10数据集[9-10]包含10大类,共有60000张彩色图像。在Cifar数据集的官网上,提供3种形式的数据集,一种供python程序使用,一种供Matlab程序使用,还有一种供C程序使用。本文采用的是python程序,则选择python版本的数据集,下载解压后得到的数据集存在5个批次的训练数据集和1个批次的测试数据集。每一个批次都是用python中的cpickle库打包好的形式,此时每个批处理文件都包含一个字典。在进行程序运行时,可直接进行调用。

与Cifar10相比,Cifar100数据集中的图像共分为100类。每一类包含600张图像,其中500张作为训练图像,100张作为测试图像。故Cifar100共有50000张训练图像,10000张测试图像,与Cifar10相同,只是图像的种类变多了。

Cifar100数据集首先被分为20个超类,即比较笼统的分类,随后每个超类下精细地分为5个小类。这样数据集中所有图像都具有一个“粗糙的”标签和一个“精细的”标签。

3.2 数据集为Cifar10、网络为ResNet-Boosting的输出

对基础网络的改进方式为[11-12]:在每一个层的输出处引入一个线性的全连接层,每一个链接层的结果都可以看做与一个弱分类器对应,最终将所有的弱分类器进行结合,得到最终的输出。

实验结果如图4,每训练一个epoch便进行一次测试,得到实时准确率。

图4 Cifar10-ResNet-Boosting的最终输出结果Fig.4 The final output of Cifar10-ResNet-Boosting

将所有准确率用折线图表示,结果如图5所示。

图5 Cifar10-ResNet-Boosting网络的准确率Fig.5 The accuracy of Cifar10-ResNet-Boosting network

由图5可以清楚地看出,随着遍历数据集次数的增加,准确率先上升,随后不再明显上升,而是在90%的附近波动,如表1所示。

表1 Cifar10数据集下的结果对比Tab.1 Comparison of results in Cifar10 data set

由此看出,设计的网络对分类准确率的提升没有较大的作用,故提出出现这种现象的原因是否由于进行实验的数据集太小,仅仅拥有10类,如若改变数据集,增大数据的类别,是否对准确率有所提升,则进行下一组对比实验进行论证。

3.3 Cifar100-对应ResNet网络及ResNet-Boosting网络的训练及测试

在进行了上述的对比实验后,可以看出设计的网络对于图片分类的准确率没有较大的提升,与原来的网络没有较大的差异,则提出设置如下:改变数据集,再次进行对比实验,观察是否对准确率有所提升。

部分参数设置如下:

(1)Epoch=135,即程序将对输入的数据遍历135次。

(2)Batch_Size=125,即每次输入125张图片。由于Cifar100数据集有50000张训练图片,故遍历一次输入,输出400次Loss及Acc。但相对于Cifar10而言,其数据类别增多。

(3)LR=0.001,减小学习率,有助于观察更加细微的变化。数据集仍为Cifar100,网络仍为ResNet网络,未做任何改变。同样的每训练一个Epoch便进行一次测试,得到实时准确率。

则最终结果如图6所示,在进行了135次训练及测试后,最后一次网络的分类准确率为67%。

图6 Cifar100-ResNet网络的最终输出结果Fig.6 The final output of Cifar100-ResNet network

将所有实验结果通过折线图表示,结果如图7所示。

由图7可知,随着遍历数据集次数的增加,准确率先上升在某一处便不再明显上升,最终准确率的值在66%附近波动。

图7 Cifar100-ResNet网络的准确率Fig.7 The accuracy of Cifar100-ResNet network

实验结果如表2所示,通过设计的网络对Cifar100数据集训练及测试的结果,总共遍历数据集135次,准确率先上升,最后在70%附近波动。

表2 Cifar100数据集下的结果对比Tab.2 Comparison of results in Cifar100 data set

均采用数据集Cifar100,同样遍历数据集135次,学习率均为0.001,Batch_Size=125,在相同的参数情况下,仅仅是网络的不同,进行实验验证,以此探寻所设计的网络对准确率的提升是否有效果。

经过实验论证可以看出,在ResNet网络下,最终的准确率平衡在66%左右,而设计的网络准确率最终稳定在70%左右,Loss均逐渐减小。由此可以看出设计的网络对准确率的提升是有效果的,尽管提升不是很多。

由此印证了3.2小节的对比实验可能由于数据集类别过少,导致准确率没有任何提升效果,在更大的数据上分类准确率有了提升,虽然提升不是非常明显。

4 总 结

通过理论设想及实验证明可知,对ResNet网络结构进行改变,在每一层的输出处引入一个线性的全连接层,将每一个全连接层的结果看做对应于一个弱分类器,在训练网络的过程中,分别计算其与标签之间的Loss,最终的Loss是将得到的Loss进行相加。并且每训练一个Epoth便测试一下准确率。实验中设置的Epoth为135,可以看到通过不断的训练和测试,最终的准确率在不断上升。与此同时存在的不足是:在每一层的输出引入一个线性的全连接层,最后利用Boosting的思想时,仅仅是将所有结果进行简单的求和,考虑不是很全面,应该考虑更多的可能来实现基于Boosting的思想对ResNet网络结构的改变,得到更好的对ResNet网络训练的方法。同时在Cifar10的数据集上准确率基本上没有较大的变化,但在Cifar100数据集上有所提升,应该再对更多类别的数据集进行实验,观察是否会有更加明显的提升。

猜你喜欢

网络结构残差准确率
基于双向GRU与残差拟合的车辆跟驰建模
乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析
不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨
2015—2017 年宁夏各天气预报参考产品质量检验分析
基于残差学习的自适应无人机目标跟踪算法
基于递归残差网络的图像超分辨率重建
高速公路车牌识别标识站准确率验证法
基于广义混合图的弱节点对等覆盖网络结构
体系作战信息流转超网络结构优化
基于互信息的贝叶斯网络结构学习