基于轻量级网络的眼表疾病识别方法研究
2023-11-25陈荣周子昂姜永春谢鹏飞
陈荣,周子昂,姜永春,谢鹏飞
(1.青岛黄海学院大数据学院,山东青岛 266427;2.哈利法科学技术大学工程学院,阿联酋阿布扎比 127788)
0 引言
眼表是由结膜、角膜及其附件结构组成的特殊黏膜系统。眼表疾病(Ocular Surface Disease, OSD)泛指任何损害眼表系统结构和功能的疾病[1]。临床上常见的眼表疾病包括翼状胬肉、结膜色素痣和角膜炎等,严重降低日常生活质量,影响着全球20%以上的人口。眼科医生通常使用专业的医疗资源来检查OSD,如裂隙灯、共聚焦显微镜和光学相干层析成像等[2]。但这些传统医学检查方法存在一些缺点:专业医疗设备价格昂贵且只能在医院使用、患者就诊成本高以及专业眼科医师供需紧张。
近年来,人工智能在医学图像识别方面取得了巨大进展和突破,具有高分辨率摄像头的智能手机可以轻松获取高质量的眼部表面图像数据,许多常见的眼表疾病在不借助医疗仪器设备的情况下,就已具备十分清晰且易辨别的临床特征,这使得依靠智能手机采集的眼表照片来自动筛查疾病成了可能[3]。因此,通过人工智能技术处理和分析智能手机眼表照片,可以及时筛查和监控眼表疾病的发展。
目前许多研究团队采用经典的卷积神经网络方法,在眼表疾病识别上取得了一定的进展。Li 等[4]采用卷积神经网络对角膜炎、翼状胬肉等常见的眼表疾病进行自动识别。Xu 等[5]使用5 种深度学习算法(VGG-16、ResNet-101、InceptionV3、Xception 和Inception-ResNetV2)自动检测和评估角膜炎患者。但是研究数据依然是裂隙灯和共聚焦显微镜等医学仪器采集的眼表图像。后来,Li 等[6]融合裂隙灯图像和智能手机图像构建AI角膜疾病筛查系统,研究中对比了GoogLeNet、ResNet 和DenseNet 三种网络的性能。Chen 等[7]使用改进的DenseNet 方法对智能手机采集的眼表图像来诊断眼睛表面是否患病。尽管以上方法取得了良好效果,但先进的网络模型通常存在大量的参数和较深的网络层,导致在嵌入式设备、智能手机等低资源平台上难以部署。
为了进一步平衡计算资源和识别性能的关系,许多轻量级网络被研究,与用标准卷积(Conv)来构建深度网络不同,设计轻量的构建单元被证明是开发更轻、更高效网络架构的有效途径。MobileNetV1 引入深度可分离卷积(DWConv)轻量单元开发了一个全新高效的轻量级网络,并应用于移动端视觉任务,进而通过跳跃连接的倒残差模块提高了性能[8]。Zhang等[9]利用分组卷积和通道混洗操作构建了ShuffleNetV1和ShuffleNetV2。华为设计的GhostNet[10]提出廉价且高效的幽灵(Ghost)模块来生成更多样化的图像特征。CondenseNetV2[11]提出稀疏特征重激活模块构建轻量网络,增加了特征的利用效率。轻量级网络的出现,在一定程度上弥补了深度网络计算效率的问题,但轻量单元提取图像特征的效果较差,从而导致识别眼表疾病的准确率较低。
针对上述问题,本文以标准卷积为主,引入深度可分离卷积和Ghost模块这两个轻量单元来辅助构建一个聚集模块,以低成本和高效的方式并行学习丰富多样的眼表图像特征,增强特征提取的能力。并且基于聚集模块进一步开发轻量级网络应用于眼表疾病识别,大大减少了网络模型的参数量,实现了眼表疾病的精准筛查,较好地平衡了计算效率和识别性能之间的关系。
1 网络模型概述
1.1 聚集模块
聚集模块先将标准卷积、深度可分离卷积和Ghost 模块集成到一起,并行地提取并拼接眼表图像特征,再采用通道混洗操作实现不同特征通道之间的信息交流,如图1所示。该模块每条路径生成相同数量的特征图,代表不同特征提取方式学习到的信息,可以增加提取眼表图像特征的多样性。
图1 聚集模块
深度可分离卷积使用经过分解的卷积算子来替代完整卷积核提取眼表图像特征。它把标准卷积分成两个独立的层,能够打破输出特征通道数量与卷积核尺寸之间的交互。第1层是深度卷积层,对每个输入通道用单个的n维卷积核来实现轻量的滤波操作;第2 层是逐点卷积层,对上一层的输出用尺寸为1×1的卷积核进行滤波操作。
Ghost 模块由部分通道的卷积和矩阵变换组成,如图2 所示。针对n维的输入,首先采用部分通道的卷积核(u≪n)生成u个特征图,再对其每一个输出特征图都进行高效的矩阵变换来计算得到新的特征映射,最后将部分通道卷积和矩阵变换所得的所有特征图拼接起来,就得到了Ghost模块的全部输出结果,即n=u+u(v-1)个特征图。简单高效的Ghost 操作可以与任何的矩阵变换运算相结合,如小波变换、仿射变换和分组卷积等。
图2 Ghost模块
在微型网络中,单个的轻量单元倾向于通过限制通道数量来约束网络复杂度,这严重影响了网络的精度。聚集模块中采用三条路径(DWConv、Conv 和Ghost)来生成多样化的特征图,这些输出特征图直接拼接在一起,会产生一个副作用:每个通道的输出只能由输入通道的某一小部分获得,不同路径之间的特征无法交互,严重阻塞了特征通道之间的信息流,进而降低了特征的表达能力。因此,本文引入通道混洗操作对每条路径输出的特征图进行重新排列,保证后面的输入来自前面不同的提取特征路径,以达到特征融合的目的。通道混洗操作先将DWConv、Conv 和Ghost 这三条路径的输出看作一维向量,再分解为二维矩阵,最后转置后拉伸展平为一维,这样所有输出的特征通道都得到了重新排列。
1.2 轻量级网络—聚集网络
本文提出的轻量级网络基本遵循了密集连接的方式来设计,为了直接减少网络计算量,放弃了“网络足够深”的概念,在密集块中使用较少数量的聚集模块来提高特征提取效率。另一方面,ShuffleNetV2 曾提出了四条搭建轻量级网络的指导原则:相同大小的通道可以最小化访问内存;分组卷积数目过多会影响计算效率;网络碎片化会降低程序的并行能力;元素级操作对网络效率影响较大(如ResNet 中的跳跃连接)。本文根据以上4条规则来开发聚集网络,第1条采用聚集模块作为基本构建单元,其中每条路径的输出通道大小保持相同以减少内存访问成本。Ghost模块中仅有少量的分组卷积操作,对应第2条原则。针对第3条,本文采用密集连接将所有输出特征拼接起来形成了一个整体,避免了网络碎片化的产生。最后,网络中的梯度传播是采用串联拼接,而非跳跃连接。
1.3 本文网络结构
图3展示了本文自动识别眼表疾病的轻量级网络结构,它由1个初始的聚集模块、3个密集块、2个过渡块和一个分类层组成。初始的聚集模块用来提取全局图像特征。每个密集块只包含4个聚集模块层,被用于提取图像特征。聚集模块中每条路径的输出表示该路径贡献的新信息。每个聚集模块后面都包含一组批归一化处理(BN)、激活单元和1×1 Conv的集合函数。过渡块将2个相邻的密集块连接起来,通过减少特征图的大小和数量来提高计算效率。每个过渡块由4 部分组成:BN、激活单元、1×1 DWConv 和平均池化层。高效的DWConv 将通道数量减少一半,步长为2的平均池化层用来减少特征图的一半尺寸。分类层将前面所得到的特征通道映射为2个特征图,代表预测属于正常或异常眼表的类别。整体的网络主干如图4所示。
图3 轻量级网络结构
图4 主干网络结构
2 实验及结果分析
2.1 实验数据
本文实验数据集为Github网站的开源数据集,眼表照片均采用智能手机拍摄。数据集中眼表图片共953 张,其分辨率为682×512,每张照片的症状由多名专业眼科医生共同标注,其中正常眼表467张,异常眼表照片486张。本文将数据集以7:3的比例分为训练集和测试集,并采用水平翻转和随机裁剪来扩充数据集,以缓解网络训练时数据量较少导致的过拟合问题。
2.2 实验环境
硬件环境:显存12GB 的NVidia GeForce GTX 1080 Ti GPU;软件环境:Windows 10、Python 3.7、深度学习框架Pytorch 1.6.0、CUDA 10.1。
2.3 训练参数设置
本文实验采用交叉熵损失函数和带动量的随机梯度下降算法训练模型,在优化算法中参数动量和衰减因子分别设置为0.9 和0.000 1。网络训练轮次Epoch 和批尺寸Batchsize 分别设置为200 和16,初始学习率设置为0.15,训练到第100和150个Epoch时将学习率分别衰减10倍。实验过程中的训练损失、测试误差与Epoch 的关系如图5 所示,可以看出损失曲线的整体趋势比较平滑,早期训练损失下降较快,在经过100个Epoch之后,模型慢慢趋于收敛。
图5 聚集网络识别OSD的训练曲线
2.4 结果分析
为了验证本文所提出的轻量级聚集网络方法识别眼表疾病的性能,先将其应用于眼表图像数据,计算出识别效果;再与ShuffleNetV2、MobileNet 系列、GhostNet 和CondenseNetV2 等当前先进的轻量级网络方法进行对比实验,以证实其高效性。
图6展示了聚集网络自动识别眼表疾病的混淆矩阵。模型将13 例正常眼表误分类为异常,将14 例患病眼表识别为正常,其中近一半被误判的眼表图像为结膜充血一级,主要因为患有轻度充血症状的眼表图像与正常眼表十分相似,图像显著性特征不够明显难以分类。同时计算了眼表疾病识别的特异性、召回率、精度、F1 分数和准确率,分别达到了90.71%、90.41%、91.03%、90.72%和90.56%,各项指标均在90%以上,说明总体上聚集网络对眼表图像的识别效果良好。
图6 聚集网络识别OSD的混淆矩阵
表1对比了不同轻量级网络方法对眼表疾病自动识别的结果。本文所提出的聚集网络模型参数量(Paras)为0.24M,比其他轻量级网络减少到原来的1/10。聚集网络仅需要1.88B的计算量(FLOPs),尽管Ghost-Net模型采用的高效幽灵模块和减少碎片化操作在计算效率方面比较有效,只需要0.82B的FLOPs,但其识别效果不佳,而聚集网络则实现了90.56%的眼表疾病识别准确率(ACC)。因此,当前先进的轻量级网络方法在识别眼表疾病方面的准确率普遍较低,本文所提出的基于聚集模块开发的轻量级网络能够较好地实现计算成本与性能之间的平衡。
表1 模型实验结果
3 结论
采用智能手机拍照来自动筛查眼表疾病,对医生辅助诊断、后续精准治疗以及康复具有重大意义。本文针对深度卷积神经网络难以在可移动设备端部署、现有轻量级网络识别效果不佳和较难平衡计算资源与性能的关系等问题,提出了基于聚集模块的轻量级眼表疾病自动识别方法。该方法构建的聚集模块以低成本、高效的方式并行提取丰富多样的眼表特征,能够改善不同特征通道之间的信息流。通过实验对比证明,本文所提出的方法以更少的参数量和计算成本实现了更高的识别准确率。但该模型参数量较少,当眼表数据集足够大时,其参数不足以表示海量数据中的所有图像特征,在模型训练过程中,聚集网络只会尽可能多地学习大部分可分辨的特征,从而表现出一定的泛化能力。未来可以进一步优化网络体系架构,并考虑神经网络架构自动搜索技术,找到合适的网络宽度和深度来实现速度与识别性能之间的最优平衡。