改进RetinaNet的刺梨果实图像识别*
2021-04-09闫建伟张乐伟赵源张富贵
闫建伟,张乐伟,赵源,张富贵
(1.贵州大学机械工程学院,贵阳市,550025;2.国家林业和草原局刺梨工程技术研究中心,贵阳市,550025)
0 引言
近年来,随着深度学习理论研究的不断深入,基于Anchor目标检测框架已成为国内外卷积神经网络领域研究的热点。当前,对基于Anchor的目标检测框架的研究主要分为两类:一类是基于区域候选框的二阶段目标检测框架(two stage)算法,该方法先在图像上生成若干可能包含目标的候选区域,然后通过卷积神经网络(Convolutional Neural Network,CNN)分别对这些候选区域提取特征,最后通过卷积神经网络进行目标位置的回归与类别的识别,典型的算法有R-CNN(Region Convolutional Neural Network)[1]、SPPNet(Spatial Pyramid Pooling Networks)[2]、Fast RCNN[3-4]、Faster RCNN[5-6]、FPN(Feature Pyramid Networks)[7]、R-FCN(Region-based Fully Convolutional Network)[8]等;另一类为一阶段目标检测框架(one stage)算法,该方法直接从图片获得预测结果,将整个目标检测任务整合成一个端到端的任务,而且只处理一次图像即可得到目标的类别与位置信息,典型的算法有SSD(Single Shot multibox Detector)[9]、YOLO(You Only Look Once)[10-12]、DSSD(Deconvolutional Single Shot Detector)[13]、FSSD(Feature fusion Single Shot multibox Detector)[14]以及RetinaNet[15]等。二阶段目标检测算法比一阶段目标检测算法具有更高的准确率和定位精度,而一阶目标检测算法RetinaNet在COCO测试集上的结果高于二阶目标检测模型[15]。在RetinaNet目标检测算法方面,宋欢欢等[16]将其网络层数增加到152层,并且加入了MobileNet的设计思想,对其加速和压缩,有效地提高了准确率。刘革等[17]为了提高模型的前向推断速度,用MobileNet V3替换ResNet-50[18]用于基础特征提取网络。张物华等[19]在基础特征提取网络中加入特征通道注意力机制模块,突出特征图中的特征通道,以提高精度。王璐璐等[20]在C3、C4层加入通道注意力模块,同时,为缓解网络的过拟合问题,在通道注意力模块的全连接层加入随机失活机制,从而增强网络的鲁棒性。谢学立等[21]在RetinaNet结构中分别添加bottom-up短连接通路以及全局上下文上采样模块,用来增强检测层特征的结构性和语义性。以上改进虽然准确率有一定提高,但检测效率却显著降低了。
针对生产车间刺梨果实识别,人工分拣分级效率低,无法满足工业化加工刺梨果实的要求,本文拟选择一阶目标检测算法中的RetinaNet目标检测算法,以RetinaNet模型为基础,改进bias公式以及运用K-means++聚类算法,并增强数据和合理调节参数,以期实现对刺梨加工车间的果实进行高精度、快速识别。
1 数据采集与处理
1.1 数据采集
本文刺梨果实图像于2019年9月28日在贵州省龙里县谷脚镇茶香村刺梨产业示范园区采集,品种为贵龙5号,共采集图片807幅。对刺梨果实用尼康(Nikon)D750单反相机进行拍照,原始图像格式为.JPG,分辨率为6 016像素×4 016像素。刺梨果实图像采集样本示例如图1所示。
图1 刺梨果实图像样本示例
1.2 数据集样本及标签制作
本文从拍摄到的807幅刺梨果实照片中,将刺梨果实分为6类。通过ACDSee20软件将807幅大小为6 016像素×4 016像素的原图裁剪为多幅大小为902像素×602像素的完全包含刺梨果实的样本,对裁剪后的样本进行上下翻转以及旋转45°、90°和270°,最终得到7 426 幅刺梨样本。再使用LabelImg软件对7 426 幅刺梨样本进行样本标签制作。
1.3 刺梨果实分类
刺梨果实图像分级简图,如图2所示。
1.1
针对采摘后的刺梨果实进行分级,按颜色、果实好坏等情况,将刺梨果实图像分为6类:1.1、1.2、2.1、2.2、3.1、3.2;其中,1.、2.、3.等按照颜色不同进行分级[1.:颜色为青色、2.:颜色为金黄、3.:非以上两种情况]。.1、.2等按照果实好坏进行分级[.1:非坏果、.2:坏果]。分类后各类刺梨照片数量较均衡,有利于后期处理。
2 网络模型RetinaNet的改进
RetinaNet模型由特征提取网络、特征金字塔网络、子网络等三个模块构成,其网络结构如图3所示。图3中A表示特征提取网络,使用深度残差网络ResNet来完成对图像特征的初步提取;B表示特征金字塔网络,将A中产生的特征图进行重新组合,完成对图像特征的精细化提取,以便能更好地表达图像信息;C表示子网络,用于对待检测的目标图像分类和定位。
图3 RetinaNet的网络结构
2.1 偏差bias的改进
由于RetinaNet的核心是Focal Loss,在Focal Loss中,用于分类卷积的bias,可以在训练的初始阶段提高positive的分类概率以及决定神经云产生的正负激励的难易程度。针对其无法准确取值,在原有式(1)的基础进行了改进,改进后的计算公式如式(2)所示。
bias=log[(1-π)/π]
(1)
bias=αlog[(1-π)/π]β+γ
(2)
α、β、γ可以控制bias的取值,根据实际情况,最终得出α=1.0、β=1.1、γ=0.0、π=0.01,使得预测图像目标的准确性上升。
2.2 K-means++聚类算法
Anchor机制可有效解决目标检测任务中存在的尺度及宽高比例变化范围过大等问题。由于原始RetinaNet使用的是非刺梨样本的数据集,所以原始RetinaNet所选定的Anchor尺度和宽高比例在本文的检测任务中并不适用。
本文运用K-means++聚类算法[22],使其更加适合刺梨样本,定位框更加精准。通过对刺梨数据集的真实标注框进行聚类操作,真实标注框长宽映射到模型输入大小下的聚类结果如图4所示。
图4 真实boxes长宽聚类值
由图4可知,有三个聚类簇,刺梨的宽高聚集在[35,33]、[40,39]以及[45,44]附近。
因此,本文将[90×90,125×125,160×160,195×195,230×230]作为对应的5个特征层的Anchor尺寸,以[0.5,1.0,1.5]作为Anchor的长宽比。
3 网络模型训练步骤
改进后的卷积神经网络模型,对刺梨果实进行识别的训练步骤如图5所示。
待训练的刺梨果实图片,首先在特征提取网络图5(a)中由深度残差网络ResNet50来完成对图像特征的初步提取;其次在特征金字塔网络图5(b)中,将图5(a)中产生的特征图进行重新组合,以便能更好地表达图像信息;最后在子网络图5(c)、图5(d)中,运用K-means++聚类算法优化Anchor参数,以及对其中的bias公式进行改进,使其分类和定位更加准确。
图5 卷积神经网络模型改进后的训练步骤
4 试验与结果分析
4.1 软件及硬件
电脑配置:Windows 10、64位操作系统。笔记本电脑,GeForce GTX 1050 Ti 显卡,8 G显存;Intel(R)Core(TM)i5-8300H处理器,主频2.30 GHz,磁盘内存128 GB,编程语言是Python编程语言。
从7 426幅刺梨样本中,选出90%即6 683幅刺梨样本进行训练,余下10%即743幅刺梨样本进行最终检测。采用RetinaNet算法,在Keras框架下,并且设置该模型的batch-size为1、epochs为50、steps为1 000。
4.2 结果分析
4.2.1 准确率和损失率对比
样本识别准确率Acc的计算如式(3)所示,即预测正确的样本比例。
(3)
式中:TP——正样本被正确识别为正样本;
TN——负样本被正确识别为负样本;
N——测试的样本数。
改进前后RetinaNet目标检测算法在不同训练轮次的准确率和损失率如图6、图7所示。
图6 原始RetinaNet目标检测算法在不同轮次的平均Acc与loss曲线
图7 改进后RetinaNet目标检测算法在不同轮次的平均Acc与loss曲线
由图6、图7可知,由于改进了RetinaNet目标检测算法的核心部分Focal Loss中的bias公式,针对刺梨果实的图像识别,改进的RetinaNet目标检测算法训练集、测试集的准确率都在90%以上,相对于原始RetinaNet目标检测算法,训练集、测试集的准确率均提高1.80%;训练集损失率与测试集损失率的收敛趋势相同,训练集、测试集的损失率降低了1.27%。可见,改进的RetinaNet目标检测算法对刺梨果实的图像识别具有较高的识别率。
4.3.2 标记框对比
随机选取一张未经训练的刺梨果实照片(像素大小:902×602)如图8所示,分别在原始RetinaNet目标检测算法与改进RetinaNet目标检测算法进行识别,识别效果(只保留置信度为80%以上的识别框)如图9、图10所示。
图8 未经训练照片
图9 原始RetinaNet目标检测算法识别效果
图10 改进后RetinaNet目标检测算法识别效果
由图9、图10可知,改进后的RetinaNet目标检测算法相对于原始RetinaNet目标检测算法有较好的效果:可以使Anchor尺寸更加接近真实值,从而降低模型的训练难度;识别准确率有不同程度提高;在识别准确率80%以上时,可以检测出更多的刺梨果实;原始RetinaNet目标检测算法中错误的标记框不再出现。
4.3.3 6种不同刺梨果实分级对比
在未经过训练的刺梨样本中,按照6种不同刺梨果实分级方式,随机各选取出1种,裁剪拼接成像素大小为902×602的图片,如图11所示,将其分别在原始RetinaNet目标检测算法与改进RetinaNet目标检测算法进行识别,识别结果如图12、图13所示。
图11 6种刺梨果实拼接
图12 原始RetinaNet目标检测算法识别效果
图13 改进后RetinaNet目标检测算法识别效果
从未经训练的588幅刺梨果实样本中随机选取若干照片,分别在原始RetinaNet目标检测算法和改进后的RetinaNet目标检测算法中进行分类识别。图片像素对检测时间有一定的影响,提供检测的单幅照片像素为300×300;含有单个刺梨果实。6类刺梨果实对比情况如表1所示。
改进前后6类刺梨果实识别准确率及检测时间对比如表1所示,改进后的RetinaNet目标检测算法对6类刺梨果实的识别准确率均有提高,提高的幅度从0.14%、0.68%、1.32%、1.83%、2.60%到4.21%不等,识别准确率最高提高了4.21%,识别准确率平均提高了1.80%。
表1 改进前后6类刺梨果实识别准确率及检测时间对比
单个刺梨果实检测时间为由60.99 ms缩减到57.91 ms,降低了3.08 ms,与原始RetinaNet目标检测算法检测时间相比缩短了5.05%。
5 结论
1)本文针对原始RetinaNet目标检测算法进行了改进,通过改进RetinaNet框架中Focal loss的bias公式、运用维度聚类算法找出Anchor的较好尺寸来改进原始的RetinaNet目标检测算法。本文训练出来的识别模型对加工车间下的刺梨果实准确率较高,能够为刺梨果实的快速识别奠定基础。
2)通过改进,与原始RetinaNet目标检测算法相比,本文改进RetinaNet算法使标记框更加准确;识别准确率更高,最高提升了4.21%,平均提高了1.80%。单幅单个刺梨果实检测时间由60.99 ms缩减到57.91 ms,降低了3.08 ms。本文改进RetinaNet算法平均识别准确率均有不同程度提高,检测时间均有不同程度降低。
3)本文改进RetinaNet目标检测算法,为工业生产刺梨加工车间的刺梨果实快速识别提供参考。