基于改进Mask R-CNN的多标签甲状腺结节检测模型
2023-06-07吴雯娟邓梓杨邱桃荣张卫平
吴雯娟,戚 琪,邓梓杨,邱桃荣*,张卫平,徐 盼
(南昌大学a.数学与计算机学院,江西 南昌 330031;b.第一附属医院超声科,江西 南昌 330006)
全球甲状腺癌发病率呈上升趋势,为了能够提升甲状腺结节诊断水平,缩小不同医疗条件下甲状腺结节诊断水平的巨大差异,中华医学会超声医学分会浅表器官和血管学组组织专家于2017年开始着手起草符合中国国情的C-TIRADS[1-3]。C-TIRADS标准在临床实践中表现了很好的诊断效果[4]。基于计算机辅助诊断的方法来自动检测识别甲状腺结节能够提升临床医生的诊断效率,降低劳动强度[5-7]。因此,针对C-TIRADS来研究基于深度学习的甲状腺超声图像诊断方法具有重要意义。结合C-TIRADS标准构建一个能定位结节并预测结节恶性风险等级以及预测结节的病理特征的多标签目标检测模型,能够更好地与当前临床诊断流程结合,为临床医生提供更多参考依据,提升临床医生的诊断效率。
为了能够有效结合C-TIRADS标准实现甲状腺结节恶性风险智能诊断,本文提出一个基于深度学习的多标签甲状腺结节检测模型。通过在本文数据集上对比Mask R-CNN[8]、YOLOv5[9]、DETR[10]模型对甲状腺结节的检测性能,选用了在数据集上甲状腺识别效果最好的Mask R-CNN模型作为基准模型。为了解决基准模型无法完成多标签预测以及检测效果不理想的问题,将基准模型的特征提取网络、检测头、锚框进行了优化改进,使用ResNet152-FPN替换原有特征提取网络来提升模型特征能力,设计全新的卷积多标签检测头结构来对结节病理特征进行多标签预测,基于医学先验知识对模型锚框尺寸及比例进行自定义来提升模型的定位精度,最后采用迁移学习的方式对改进后的模型进行训练,使模型进行得到进一步提升。
1 目标检测模型及特征网络简介
1.1 Mask R-CNN
Mask R-CNN模型是两阶段目标检测模型。模型的工作流程可以分为在图像上生成推荐候选区域、对候选区域图像进行特征提取与预测两步,工作流程如图1所示。
图1 Mask R-CNN模型工作流程图Fig.1 Workflow of Mask R-CNN
输入图像经过特征提取网络得到特征图,特征图输入RPN网络用以生成候选框,候选框应用在特征图上生成感兴趣区域,感兴趣区域输入到检测头中用以分类与回归框预测。
1.2 ResNet
何凯明等[11]提出了ResNet网络与Residual残差结构,Residual结构如图2所示,通过快捷连接将上几层的输出与该层输出相加。Residual结构的提出使得ResNet模型具有了恒等映射能力,ResNet模型使用带有快捷连接的残差卷积块替代普通卷积块,解决了构建深层网络可能出现的退化问题。并且在ResNet模型中引入了Batch Normalization(BN)层,解决模型训练过程中可能出现的梯度消失/爆炸问题。
图2 Residual模块结构Fig.2 Structure of Residual module
表1 ResNet模型结构Tab.1 Structure of ResNet
表1展示了不同深度ResNet模型的结构,模型由不同数量的多种卷积块堆叠构成。ResNet152相比ResNet101具有更多的卷积层,理论上具有更好的表达能力。
1.3 FPN
FPN(Feature Pyramid Networks)[12]网络结构如图3所示。左侧为普通的卷积神经网络,右侧为FPN从卷积神经网络中提取的不同层次的特征图,并把它们通过侧边连接相连,使得FPN网络中的每一层特征图都包含多个尺度的信息,然后再基于特征金字塔中各个层次的特征图分别进行预测,论文中实验数据表明,FPN网络能够提升一些目标检测模型的性能。
图3 FPN模型工作流程图Fig.3 Workflow of FPN
2 基于Mask R-CNN的多标签模型构建与优化
2.1 特征提取网络的选取与改进
基准模型选取实验中Mask R-CNN模型使用ResNet101作为特征提取模型进行训练,模型的性能还有进一步提升的空间。为增强模型的特征提取能力,本文选取更深的ResNet152模型作为特征提取网络进行对比实验。有研究表明FPN模型能够提升目标检测模型的性能,所以本文将在ResNet模型上加入FPN结构进行训练与实验对比,通过对比实验为模型选取最佳特征提取网络。ResNet-FPN模型的结构如图4所示,左侧为ResNet网络,FPN通过提取ResNet网络中不同层次特征图进行连接得到特征金字塔。
图4 ResNet-FPN工作流程图Fig.4 Workflow of ResNet-FPN
2.2 预测分支结构设计
为Mask R-CNN模型构建能够进行多标签分类预测的检测头,本文提出的检测头结构如图5所示。
图5 检测器结构示意图Fig.5 Structure of Detectors
其中class1分支代表一个二分类预测分支,用以预测感兴趣区域是否为甲状腺结节,为原Mask R-CNN模型自带类别预测分支结构,class2分支为多标签分类预测分支,是在原模型基础上为完成多标签预测结节病理特征新增加的分支结构。本文设计了卷积检测头与全连接检测头两类,并且尝试在class2预测分支上增加约束,将class1分支预测结构输入到class2分支来进一步约束class2分支预测结果。本文将在不同的检测头上进行对比实验,将实现中表现最好的检测头作为改进模型的检测头结构。
2.3 基于临床先验知识指导的锚框定义策略
以特定的先验医学知识为指导模型改进能够提升模型性能[13-15]。本文为了使模型能够生成更加贴合甲状腺结节形态的候选框,对数据集中甲状腺结节的纵横比进行统计发现纵横比范围在(0.2,1.8),结节的尺寸越大,结节的纵横比往往更小,基于数据集中结节的纵横比分布情况对锚框的尺寸和纵横比进行定义,具体数据见表2。
表2 不同层级特征图上的锚框大小及纵横比Tab.2 Size and aspect ratio of anchor on different level feature map
2.4 基于迁移学习的改进模型预训练
迁移学习通过将在任务A上学习到的规律应用到任务B上,通过一些已经学习到的A、B任务的共同规律来节省B任务的学习成本,同时提升任务B的泛化性能。为了更好地训练模型,本文采用迁移学习对模型进行训练,为特征提取网络、检测头分别设计了训练方案。
特征提取网络的训练步骤,首先冻结除检测头以外的所有参数,对检测头参数进行训练,然后对整个模型中所有参数进行训练微调。
本文为卷积检测头和全连接检测头分别设计了两种训练方案,针对全连接检测头的训练方案有:(1)冻结所有已经训练过的卷积层,对新构建的检测头进行训练(2)冻结所有已经训练过的卷积层和全连接层,仅对新增加的全连接分支进行训练,再对整个模型进行微调。针对卷积检测头的训练方案有(1)冻结所有已经训练过的卷积层,对新构建的检测头进行训练(2)冻结所有已经训练过的卷积层与全连接层,先训练检测头中新增加的残差块,再对新增加的多分类全连接层进行训练,再对整个模型进行微调。
本文将通过实验对比特征提取网络采用迁移学习和非迁移学习策略训练的模型性能,对比各卷积头在不同训练方案下的模型性能。
3 模型对比实验与结果分析
3.1 甲状腺超声图像数据集
本文构建的甲状腺超声图像数据集具体数据如表3。对于超声图像中的单个结节可能同时具备多个病理特征标签。在实验中将原始数据集按照8:2的比例进行划分,80%的数据作训练集,20%的数据作为测试集。
3.2 不同目标检测网络的对比实验与结果分析
本文为了选取基准模型,将数据集输入到Mask R-CNN、YOLOv5、DETR模型中进行训练对比,具体实验数据见表4。
从表4可以看出,Mask R-CNN模型在IOU阈值为0.5和0.75条件下对甲状腺结节的检测识别性能在三个模型中都是最佳,因此本文选取Mask R-CNN模型作为基准模型。
表3 数据集统计Tab.3 Dataset statistics
表4 各模型对甲状腺结节的检测性能Tab.4 Detection performance of different models for thyroid nodules
3.3 改进模型的不同主干网络与训练方式对比实验与结果分析
本文数据集在基准模型上对甲状腺结节的识别检测性能不高,因此本文首先对基准模型的特征提取网络进行改进,为模型寻找特征提取能力更强的特征提取网络,同时为了能更好训练模型对模型采用迁移学习的方式进行训练,具体实验数据见表表5。
表5 不同主干网络及训练方式的对比实验结果Tab.5 Results of different backbone networks and training methods
从表5可以看出,带有更深层的特征提取网络的模型对甲状腺结节的检测识别性能更好,且ResNet模型上增加FPN结构后性能得到明显提升。通过对比同一模型在不同训练方式下的实验数据可以看出,迁移学习策略能够为模型性能带来增益。
3.5 预测分支结构与损失函数设计
本文提出了多个检测头结构,对各检测头采用2.4节中描述的方案进行训练,具体的实验数据如下表。
表6 不同预测分支结构与训练方式的实验结果Tab.6 Experimental results of different prediction branch structures and training methods
从表6可以看出,检测头d在两个检测分支上的性能都优于其他检测头,且各分支在使用第二种训练方式进行训练时性能更佳,因此将检测头d确定为改进模型的检测头。
3.5 基于临床先验知识指导的锚框定义策略
为了能够使模型生成更加贴合甲状腺形态的推荐候选框进而提升模型对甲状腺结节的定位精度及分类准确率,本文基于临床先验知识对模型锚框尺寸进行定义,具体实验数据如下表。
表7 基于临床先验知识的锚框实验结果-aTab.7 Results of anchor definition based on clinical prior knowledge-a
表8 基于临床先验知识的锚框实验结果-bTab.8 Results of anchor definition based on clinical prior knowledge-b
从表7可以看出,有了临床先验知识指导锚框尺寸定义后模型性能在class1分支得到提升,根据表8可以看出,模型在class2分支上对各项病理特征的识别准确率都得到提升,尤其在4号特征上提升最为明显。
3.6 改进模型的预测结果展示
图6展示了模型对甲状腺结节的预测效果,左侧为对甲状腺结节轮廓进行标注的原始图像,右侧为模型对结节的预测结果,预测框定位了结节位置,预测框上方为模型对结节特征及恶性风险等级的预测结果。
图6 模型检测效果展示Fig.6 Model prediction result display
4 结论
本文构建了一个自动判别甲状腺结节恶性风险的多标签目标检测模型,对模型进行了优化改进,并且在数据集上开展了充分的对比实验来验证模型改进效果。从实验结果上看,对模型特征提取网络、检测器结构、锚框尺寸及大小的改进均有效提升了模型对甲状腺结节的识别检测性能,采用的迁移学习训练策略也帮助模型进行了更好的训练。改进模型对甲状腺结节的检测准确率达到了94.4%,对病理特征的平均识别准确率达到了88.6%。
后续将基于临床经验指导进一步探索各项特征间的隐藏关联,对检测器等进行进一步改进优化以提升模型对结节特征的分类准确度。同时也不断的扩充数据集来加强模型的泛化性能。