APP下载

基于迁移学习的中国蛇类识别研究

2022-05-16周志斌罗志聪张展榜孙奇燕

野生动物学报 2022年2期
关键词:蛇类集上准确度

周志斌 罗志聪,2* 张展榜 孙奇燕

(1.福建农林大学机电工程学院,福州,350002;2.福建省农业信息感知技术重点实验室,福建农林大学机电工程学院,福州,350002;3.福建农林大学计算机与信息学院,福州,350002)

建立一个能自动识别图像中蛇种类的强大系统是保护生物多样性和全球健康的重要目标[1]。根据世界健康组织(World Health Organization)调查结果显示,全球每年450万~540万人被蛇咬(https://www. who.int/health-topics/snakebite#tab=tab_1)。被蛇咬伤后,快速辨别蛇的种类,根据不同蛇毒进行特异性的抗体治疗是改善蛇咬伤健康数据的重要环节[2-3]。

蛇类图像分类问题的研究由于传统生物图像识别方法的特征工程复杂、计算机算力资源匮乏等局限,过去一直得不到发展[4-6]。随着深度学习在计算机视觉领域取得突破性进展,基于深度学习的图像分类技术被应用于各种分类研究[7-9],近两年来,国内外也开始将深度学习技术应用于蛇类图像的分类研究[10-17],利用卷积神经网络代替传统图像识别的特征工程,完成特征的提取和匹配,将蛇类图像输入网络中便能自动识别蛇的种类。

图像分类最主要的问题是提升分类的准确度,提升分类精度的一个最直接方法是增加数据集,通过在大量数据集上训练,提升神经网络模型的泛化能力,但是蛇的行踪隐蔽和危险性,导致蛇类照片不容易采集,且国内缺乏公开可用的大型蛇图像数据集。迁移学习的方法是解决小样本学习的有效办法[18],张皓洋[19]采用迁移学习的方法使用在ImageNet上预训练过的Inception-v3模型迁移到自己建立的数据集上,达到88%以上的分类精度。设计更加优秀的网络结构、改进网络模型也可以提升模型的分类精度,付永钦[20]受ResNet V2残差网络结构启发,以残差网络为基础提出BRC卷积神经网络。这种新型网络结构的重要组成部分为BRC块,由批标准化层、ReLU激活函数层和卷积层组成。

针对蛇类图像分类问题,搜集国内常见蛇种建立图像数据集,以常见的卷积神经网络分类模型进行图像识别试验,包括模型的设计和优化、迁移学习与微调,探索适用于国内蛇类图像识别的改进方案。

1 试验数据集

国内缺乏公开丰富的中国蛇类图像数据集[20],依靠个人采集建立数据集存在较大困难,因此通过网络收集蛇类图片的方式建立数据集。在不同搜索引擎和网络平台上下载目标蛇类的照片,根据每种蛇类的外形、斑纹和颜色等特征,剔除不符合要求的照片,建立包含6种国内常见蛇类的ChineseSnakes数据集,共1 427张图片,其中金环蛇(Bungarusfasciatus)153张、银环蛇(B.mult-icinctus)233张、竹叶青(Trimeresurusstejnegeri)284张、王锦蛇(Elaphecarinata)425张、圆斑蝰(Daboiarussellisiamensis)85张和尖吻蝮(Deinagkistrodonacutus)247张,6个类别分别放在6个文件夹中,以蛇的名称命名,以1∶9的比例分为测试集和试验集,再将其中试验集以8∶2的比例分为训练集和验证集。

2 蛇类图像分类模型

2.1 改进的神经网络模型

选择中大型网络模型VGG19[21]、ResNet50、ResNet101[22]以及轻量化模型MobileNetV2和Xception作为特征提取网络,对这5种典型神经网络模型采用相同分类网络设计,分类网络由平均池化层(global average pooling)、Dropout层和全连接层组成。具体步骤为:(1)全局平均池化操作以特征图(feature map)为单位均值化,将每层特征值合为一个值,代替卷积神经网络中传统的多层全连接层分类网络,可以有效降低参数数量。(2)由于数据集的样本量较少,为减少过拟合,在全局平均池化层后紧接着Dropout层,在训练过程中按照一定的概率将部分神经元暂时无效,减少中间特征的数量[9]。(3)最后设计匹配分类任务的全连接层,由于数据集中包含6种不同的蛇,故全连接层的输出通道设计为6个。

2.2 Adam优化算法

综合考虑学习速度和学习效果,训练时采用Adam优化算法[23]对参数的学习率动态调整。利用Adam优化算法计算更新的步长,在t时刻参数更新的计算公式为:

(1)

(2)

mt=β1mt-1+(1-β1)gt

(3)

(4)

(5)

式中:gt是t时刻的梯度,其计算如式(6)所示。

(6)

Adam优化算法的参数学习率设置为0.000 1,一阶矩和二阶矩估计的指数衰减因子beta1和beta2采用框架默认值0.9和0.999。通过试验,每个神经网络设置运行足够多的轮次(epoch),其训练精度曲线即可达到收敛。相较于图像数据集的数量,试验的神经网络模型网络复杂,表达能力过强,容易过拟合,分类网络的丢弃率dropout设置为0.5,可达到较好效果。

3 模型的训练

3.1 试验环境

试验基于Google Colab平台,它是一个Python开发环境,使用谷歌云在浏览器中运行。平台搭载的硬件环境为:12 GB RAM内存,Tesla V100 GPU和16 GB显存;平台使用Ubuntu 18.04.5 LTS操作系统,加载软件环境有CUDA 11.0.228、TensorFlow 2.4.1和Python 3.7.10。深度学习的神经网络模型以TensorFlow作为框架,调用keras的功能接口进行搭建。试验结果中的准确率曲线和损失曲线采用Python中的matplotlib可视化得到。

3.2 数据的预处理

3.2.1 在线数据增强扩充数据集

现实中蛇类的形态变化多样,拍摄时蛇的位置随机性大,根据这一特点,选择旋转、镜像和平移3种方式对图像进行数据增强,将这些操作重复应用于同一幅图像(图1)。每轮训练前,对准备输入网络的图像按批次旋转、平移和翻转变换。通过3种不同的数据增强方式,以及每个增强方式设置的随机因子,可以保证每轮训练的数据都是不一样的,有多少轮训练,数据集就扩充多少倍。

图1 数据增强过的蛇类图像Fig.1 Snake images with data augmentation

3.2.2 图像数据标准化

收集到的蛇类图像来源不同,图片尺寸大小和格式也不同,在输入神经网络模型前需要对图像数据进行标准化处理。将图像均转换为JPG格式,利用keras的image_dataset_from_directory函数以8∶2的比例设置训练集、验证集,然后图像尺寸统一调整为160像素×160像素,对数据集进行数据批处理,批样本数量(batch size)设置为128,并将调整好的图像重新缩放像数值,将像数值从[0,255]放缩到[-1,1]。

3.3 超参数的设置

为了使模型获得更好的效果,对模型的超参数进行设置。结合GPU并行运算能力,训练数据时批样本数量设置为128,每次训练进行9次迭代。训练采用Adam优化算法,初始学习率设置为0.000 1,而对模型微调时设置的初始学习率为0.000 01。训练的轮次以模型能达到收敛为标准,微调前训练300次,微调训练100次,对各个不同的神经网络模型采用统一的超参数设置。

3.4 迁移学习和微调

通过keras提供的功能接口,加载不包括顶部分类层网络的神经网络模型作为基础模型,与设计的分类网络构成新的神经网络模型。采用迁移学习策略[24],把在ImageNet上预训练获得的权值参数加载到网络中,作为特征提取网络,应用在新目标蛇类图像数据集上,提取蛇类图像的数据特征,减少训练深度学习的神经网络模型所需的数据量及试验平台的计算力,解决小型数据集在复杂神经网络结构上产生的过拟合现象[25]。图2为本研究的神经网络模型结构简图。在大多数卷积网络中,层级越高,其专用化程度越高。最初几层是非常简单和通用的功能,这些功能可以概括几乎所有类型的图像,越往上层,功能对于模型训练所依据的数据集越具体。为进一步提高神经网络在蛇类数据集上的分类效果,采用微调训练策略(fine-tune),将强制把权值从通用特征映射调整到与数据集特定关联的特征。只有训练好分类层网络,才能微调卷积网络的顶部卷积层,否则,初始时的训练损失值很大,会破坏掉微调之前卷积层学到的内容。在特征提取网络设置为不可训练的情况下,对分类层网络训练,最后解冻特征提取网络顶部的部分隐藏层,对这些不同的神经网络模型进行不同程度的微调训练,可更好地提升模型的分类效果。

图2 迁移学习的模型结构示意图Fig.2 Schematic diagram of transfer learning model structure

4 结果与分析

4.1 随机初始化下的重新学习

使用中大型网络模型ResNet50、ResNet101、VGG19以及轻量化模型MobileNetV2和Xception作为基础网络模型,采用随机初始化权值的方式在ChineseSnakes数据集上训练。表1是随机初始化的重新学习方式在未使用数据增强和使用数据增强条件下的训练结果。从表1中可以观察到未使用数据增强的模型,训练集上的准确度良好,但是验证集上的准确度不高,验证集上的损失值大;使用数据增强的模型,除VGG19外,其余4个模型在训练集上的准确度均有下降,所有模型在验证集上的准确度均高于未使用数据增强的模型(19.43%~29.73%),损失值也均大幅下降。经分析,产生这种结果的原因是模型背后学习算法的优化目标(减少损失值,最终逼近真实值)。试验所用的ChineseSnakes数据集的样本量少,用来识别的网络模型性能均较为优秀,模型的性能大于需要拟合的数据复杂度,因而未使用数据增强扩充样本的前模型能很快拟合训练集样本的特征,但没有学习到不同蛇类的通用特征,导致训练集准确度很高,验证集上的准确度很低,损失值居高不下,出现较为严重的过拟合。而数据增强扩充样本丰富了样本的多样性,模型对训练集的过度拟合得到缓解,训练集准确度下降,更多的样本提供了更多的分类特征,所以验证集上的准确度得到的较大提升,损失值下降。VGG19相比其他模型,在数据增强之后,训练集准确度没有下降反而提升,可能是因为VGG19模型出现更早,深度较浅,结构较为简单,在过拟合不严重的情况下,样本的增多仍能提升训练集上的准确度。综上,通过数据增强扩充后,虽然提升了分类效果,减少了模型的过拟合,但准确度仍有待提升。

表1 数据增强方法对训练结果的影响

4.2 预训练权值初始化下的迁移学习

4.2.1 微调前的迁移学习

在数据增强扩充数据集的条件下,用迁移学习的方式,通过加载大型数据集ImageNet上训练学习得到的神经网络权值,对5种神经网络模型初始化,将除了分类层网络以外的所有神经网络设置为不可训练模式,保留迁移的权值,仅训练分类层的神经网络,试验结果如表2所示。采用迁移学习初始化的模型,相较于随机初始化的重新训练,除了VGG19,均表现出更好的分类效果,验证集准确度提升8.39%~25.18%,基本未出现过拟合。随机初始化从零学习基本要训练400轮才能达到收敛,验证集的准确度提升缓慢,而使用迁移学习训练,只要300轮就能达到较好收敛,前半段的准确度提升迅速。这是因为在ImageNet上训练过的模型学习到了蛇类图像分类的有用特征。试验结果中验证集准确度明显优于训练集准确度的现象,是因为BatchNormalization、Dropout等图层会影响训练的准确性。在计算验证损失时,它们处于关闭状态,所以验证集的准确度高于训练集准确度。

表2 迁移学习下不同模型的训练结果

4.2.2 对迁移学习模型使用微调策略

微调的对象为特征提取网络的顶层部分,保持分类层网络可训练,在此基础上,从冻结的神经网络顶层开始将部分的网络层设置为可训练状态,使其参数能在训练中改变,微调基础模型中的高阶特征表示,以使它们与蛇图像分类任务更相关。试验中的5种模型,将特征提取网络的前5层、前10层和前15层顶层网络设置为可训练状态,分别进行试验。由于参与训练的神经网络层变多,预计进一步提升模型的分类准确度,在这一阶段应选择较低的学习率,因此设置学习率为微调前的1/10,即为0.000 01。

如图3训练结果曲线所示,在对模型进行微调前,验证集的准确率由于BatchNormalization、Dropout等神经网络层在评估结果上的影响,验证集的准确度大于训练集的准确度,在经过300轮的训练后,各模型基本达到收敛。对模型进行微调后,再经过100轮的训练,由于可训练的参数增加,模型的拟合能力再次超过数据集的复杂程度,出现过拟合现象,但是各模型在训练集和验证集上的识别准确度均获得进一步提升。针对模型特征提取网络顶层部分的神经网络进行不同层数的解冻并训练,选择提升效果最好的微调策略与微调前的分类效果进行比较,试验结果如表2所示,MobileNetV2、Xception、VGG19、ResNet50和ResNet101在验证集上的识别准确度分别提升了6.87%、7.05%、12.39%、1.25%和1.67%,各微调后的神经网络模型在测试集上的准确度最高的分别为95.80%、95.80%、95.10%、96.50%和97.20%,对ChineseSnakes数据集的平均识别准确率为96.08%,可以达到一个较好的识别效果。

图3 训练集/验证集准确度和损失值曲线Fig.3 Accuracy and loss curve of train/validation

表3 基于微调策略的迁移学习训练结果

5 结论

试验结果表明,本研究提出的分类模型改进方案可以有效提升模型的蛇类识别效果,并达到一个较高的识别准确度,证实了方案的有效性。数据增强可以减少模型的过拟合,提高泛化能力;Adam优化算法帮助模型寻找最优解;采用迁移学习的初始化网络权值方法可以大大缩短模型的训练时间,减少过拟合,提升模型分类精度;对模型进行微调,调整模型中的高阶特征,使其适应指定数据集,进一步提升了模型在ChineseSnakes数据集上的识别准确度。最终改进的模型通过数据增强、Adam优化算法、迁移学习和微调策略,5种深度学习的神经网络模型在ChineseSnakes数据集上的平均识别准确度为96.08%,其中识别准确度最高的为ResNet101模型,在测试集上的识别准确度达到97.20%。轻量型神经网络模型MobileNetV2和Xception相比传统的神经网络模型略有不足,但是相差不大,满足日常识别的要求,其参数少和体积小的优势,会是实现嵌入式蛇类自动识别系统的不错选择。

猜你喜欢

蛇类集上准确度
基于标记相关性和ReliefF的多标记特征选择
灭绝恐龙的灾星竟是蛇类繁盛的福星
灭绝恐龙的灾星 竟是蛇类的福星
影响重力式自动装料衡器准确度的因素分析
关于短文本匹配的泛化性和迁移性的研究分析
基于YOLOv4的蛇类图像识别
经济蛇类养殖与开发利用
论提高装备故障预测准确度的方法途径
Word中“邮件合并”功能及应用
师如明灯,清凉温润