基于密集残差连接的肺结节检测方法*
2024-03-06佘青山张建海
胥 阳,佘青山*,杨 勇,张建海
(1.杭州电子科技大学自动化学院(人工智能学院),浙江 杭州 310018;2.浙江省脑机协同智能重点实验室,浙江 杭州 310018)
癌症是对人生命健康的一大威胁,其中,肺癌是一种发病率与致死率较高的癌症[1]。肺癌的症状在早期并不明显,当发现较为明显的症状时,病人的肺癌已到中后期,此时病人的肺癌治愈率已大幅降低,基本无法治愈。因此,在肺癌的早期发现肺癌的症状并治愈肺癌显得十分重要。在肺癌早期,肺癌的症状以肺部出现结节为特征,这些结节的大小不一;其形状也各不相同,有实性结节、部分实性结节、磨玻璃结节;结节的位置也不一定,有的在肺内部与其他组织没有接触,有的附着在肺内部的血管与组织上,有的在肺的内表面上。医学上对肺结节的检测多使用计算机断层扫描(Computed Tomography,CT)[2]来对肺部成像,再由有经验的医师查看肺部的CT 图像判断病人肺部是否含有结节。由于CT扫描的成像特点是会形成大量的肺部扫描切片,查看肺部的CT 成像对医师来说是一个很大的挑战[3]。
随着计算机技术的发展,计算机辅助诊断(Computer Aided Diagnosis,CAD)[4]应运而生,计算机辅助诊断系统通过运行特定程序帮助医生查看CT 图像,缓解了医生负担[5]。随着深度学习技术的发展,以及深度学习在目标检测领域的进步[6-11],使用卷积神经网络(Convolutional Neural Network,CNN)[6]分析医学图像已成为一种趋势[12-13]。在使用深度学习检测肺结节的领域,已有许多学者提出了很多优秀的方法,这些方法在结构上可以大体分为两大类:一类是端到端的,利用单个网络完成对CT 图像中肺结节的检测;另一类是使用两个网络,第一个网络检测肺结节,由于单个网络检测出的肺结节含有较多的假阳结节,然后通过第二个网络去掉这些假阳结节,降低结果的假阳性[14]。鉴于CT图像的3 维特性,许多方法使用3D CNN 检测肺结节。对比研究表明,基于3D CNN 的检测肺结节方法在精度上优于2D CNN 方法[15]。Setio 等[16]提出一种多个视角的肺结节检测网络,将肺结节不同方向的2 维切片输入网络,将结果融合输出以检测肺结节。Song 等[17]提出的中心匹配策略的肺结节检测网络,无需设计边界框的参数,可以自动检测肺结节的大小与位置。Kim 等[18]提出一种端到端的多尺度渐进方法,通过学习肺结节不同上下文级别的多尺度特征检测肺结节。Ding 等[19]提出使用Faster R-CNN 算法检测肺结节,在网络中加入反卷积结构,检测CT 扫描轴向切片上的肺结节。Zhu等[20]提出了一个端到端的一阶段的肺结节检测模型,通过在U-Net 框架上融入提出的双路模块检测肺结节。Liao 等[5]引入泄露噪声或门(Leaky Noisy-OR Gate)的机制检测肺结节,基于检测结果的置信度选择前五个结节,并通过区域建议网络得出结节信息。Harsono 等[21]提出了结合迁移学习得到的RetinaNet 网络,用于肺结节检测与分类。
以这些检测网络为代表的肺结节检测方法,在计算机辅助诊断肺结节方面都取得了较好的进步,但这些肺结节检测方法多为顺序堆叠卷积核的卷积网络,在检测网络中没有实现对肺结节特征的复用,并且不同深度间的特征没有信息交流,从而对肺结节检测准确率低[22]。对此,本文提出基于密集残差连接的肺结节检测模型。本模型在3D U-Net 网络[23]中引入密集连接、残差连接与注意力机制[24],实现检测网络对肺结节特征的充分利用,实现对肺结节特征的多尺度提取,提高对结节特征的利用效率,从而提高肺结节的检测精度,解决肺结节检测结果假阳率高的问题。
1 数据集与方法
1.1 数据集
本文实验选择的数据集为LUNA16 数据集[25]。LUNA16 数据集包含888 份CT 图像与1 186 个有标注结节,这些结节标注文件由3 到4 位医生共同认定。数据集中CT 图像的切片厚度小于2.5 mm,分辨率为0.46 mm~0.98 mm。此数据集中,结节的直径分布为3 到30 mm,但多集中于3 mm~10 mm 内。
1.1.1 数据集预处理
数据集中原始的肺部CT 扫描图像中不仅包含肺结节检测区域的肺实质,同时还包括肺部周围的其他无关组织。为节约计算资源,提高检测效率,在检测结节之前需要利用图像形态学的方法将关组织去除。处理肺实质包括以下几个步骤:
步骤1 由于模型检测所用到的CT 图像是由不同扫描仪器得到,为了消除不同CT 图像间的差异,重采样所有CT 图像的体素到1 mm×1 mm×1 mm。
步骤2 根据人体组织CT 扫描成像的特点,将CT 图像的亨氏单位(Hounsfield Unit,HU)阈值按数值保留[-1200,600]区间的数据,超出此区间的体素则一律填充-1 200 或600。将图像8 位数字图像数据归一化到0 到255。
步骤3 将CT 图像对应掩码膨胀,对CT 图像使用新掩码后,将掩码外数据统一设置为170,表示肺实质的外部为水。
步骤4 得到处理后的肺实质,以*.npy 文件形式保存。
肺实质处理前后CT 图像变化的一个例子如图1所示。图1 中,第一行为原始的CT 图像,第二行为处理后的CT 图像。可以看出,在经过了一系列处理后,CT 图像去除了肺实质以外的无关区域与组织,同时增强了肺实质内部细节,使得CT 图像更加清晰,便于检测。
图1 肺实质处理前后对比图
1.2 基于密集连接的肺结节检测方法
在基于深度学习的肺结节检测网络中,大多为全卷积网络,网络使用卷积核对特征图卷积得出需要的结节信息。肺结节检测模型多以U-Net 网络为主干,采用卷积与下采样,逐步提取低维到高维的特征,利用转置卷积将特征图恢复到适合的尺寸,最后得出对应的结节信息。然而这种模型计算量大并且不能复用网络中大量的肺结特征,同时,堆叠普通卷积核在模型网络深度很深时会出现梯度消失的问题[22],这会导致网络无法进行反向传播,无法训练。He 等[26]在 2016 年提出残差网络(Residual Network,ResNet)模型,模型的核心思想为在普通卷积层前后增加一路快捷连接,卷积层的输出与快捷连接的输出在元素层面相加作为ResNet 的输出,通过增加快捷连接通路,在原有的卷积层出现梯度消失的情况下网络仍然能够传播。但是ResNet 的残差连接依旧是对输入的特征采取顺序操作,没有能够充分地利用不同深度的特征。Huang 等[27]在2017 年提出一个全新的名为密集网络(Densely Connected Convolutional Network,DenseNet)的网络结构,该结构可以将各个层的特征利用起来,使用拼接操作将之前各个层的特征作为下一层的输入,这种方式充分地利用不同层的特征,其数学表达为
式中:hi[·]表示第i层的拼接操作,xi,i=1,2,…,n表示特征图。
鉴于密集连接与残差连接各自的优点,本文提出结合密集连接与残差连接同时融入注意力机制的密集残差模块。
1.2.1 密集残差模块
为了提高网络中对结节特征的利用效率,本文提出了一个由密集连接、残差连接与注意力机制构成的网络结构,如图2 所示。在主干通路上有3 组3D 卷积,卷积核大小为3,每组卷积块包含卷积层、归一化层与线性整流单元(Rectified Linear Units,ReLU)。3 组卷积块采用密集连接的方式输入,卷积块之前各层卷积块的输出拼接后作为下一个卷积块的输入,图中曲线箭头表示各个层的特征图。三个卷积块后是注意力结构,注意力结构可以对特征图进行权重分配,增加与结节检测相关的特征图的权重,减小与结节检测无关的特征图的权重。经过权重分配的特征图与传入密集模块的初始特征图在元素层面相加后作为密集模块的输出。
图2 密集残差模块
密集模块中的注意力结构具体见图3,在通道维度上使用全局平均池化(Global Average Pooling,GAP)压缩输入的特征图到通道维度的向量,然后经过两层全连接层(Fully Connected Layers,FC)学习得到注意力系数,使用Sigmoid 激活函数将系数映射到0 到1 的范围,最后得到一组与输入特征图通道数对应的注意力系数。注意力系数与输入的特征图在通道维度相乘,得到注意力权重分配后的特征图。
1.2.2 肺结节检测网络
肺结节检测网络结构如图4 所示,网络的主要部分可分为数据预处理、特征提取与检测。特征图左下角为特征图大小,右上角为通道数。由于检测的CT 图片的通道数为1,不便于深度学习网络进行特征提取,利用数据预处理块提升通道,便于后续提取特征。特征提取模块利用搭建的针对性的网络结构提取结节特征。检测模块可以输出特定的数据格式,得到输入CT 图像中肺结节位置与概率信息。
图4 肺结节检测网络
如图4 所示,网络的输入是在肺实质中截取的包含肺结节的3 维CT 图像块,图像块的边长为96。数据预处理块包含2 个核大小为3 的3D 卷积层,两个卷积层的层数都为24。特征提取模块在传统的3D U-Net 网络即编码解码结构的基础上改进,使用1.2.1 提出的密集残差模块代替普通卷积块。
编码部分由4 个核大小为2 的3D 最大池化层与4 组密集残差模块交错排列组成,前两组分别包含2 个密集残差块,后两组分别包含3 个密集残差块。解码部分为两组核大小为2 的3D 转置卷积层与2 组密集残差组交错排列,两组密集残差组中均包含2 个密集残差块。采用跳跃连接方式连接编码部分与解码部分,在编码部分的第二组与第三组密集残差模块的输出与解码部分的第二组与第一组转置卷积层输出拼接。同时为了融合网络中的不同深度与尺度的特征信息,在跳跃连接处设置特征信息增强机制,具体为编码部分第二组密集残差模块的输出与第三组密集残差模块的输出经转置卷积后拼接,然后经过一包含3 个密集残差块的密集残差组,共同作为第二组密集残差模块的跳跃连接与解码部分相连接。特征提取模块后接Dropout 操作,以0.5的概率对网络随机失活以降低网络过拟合。检测模块包含两个核大小为1 的3D 卷积层,通道数为64与15。受区域建议网络(Region Proposal Network,RPN)[10]启发,在网络中,特征图的每个位置上都有3 个不同大小的边界框,边界框的长度根据结节直径的分布确定,分别为5 mm、10mm 和20 mm[28]。输出的结果为输入检测网络的特征图中包含结节的位置坐标与概率,分别为结节的概率p与结节在CT图像中坐标(x,y,z)与直径大小d。
1.2.3 损失函数
本结节检测模型的损失计算包含两个部分,一是检测区域内是否包含结节,即预测肺结节概率p的损失,采用分类损失,二是检测区域内结节的位置(x,y,z)与结节直径d的回归损失。在本模型中,使用检测区域与标注结节区域的交并比(Intersection over Union,IoU)[29]区分正负样本,如果两者区域的交并比大于0.5,标记为正,如果两者区域的交并比小于0.02,标记为负,其余则被忽略不被标记。分类损失采用二值交叉熵损失函数,回归损失采用Smooth L1 损失函数。
分类损失函数定义为:
式中:p表示模型预测的概率,p*表示标签的分类概率。当为正样本标签时p*=1,负样本标签时p*=0。
回归损失定义如下:
式中:L(ti,)表示Smooth L1 损失函数。
Smooth L1 损失函数定义如下:
式中:ti表示模型预测的结节回归坐标组t中参数,表示标签标注的结节回归坐标组t*中参数。
t与t*分别为预测的结节回归坐标组与标注的结节回归坐标组,定义如下:
式中:(x,y,z,d)表示模型预测的结节3 维位置信息与直径,(x*,y*,z*,d*)为标签中标注的结节3 维位置信息与直径,(xa,ya,za,da)为边界框的3 维位置信息与直径。
模型的总损失为:
式中:Lcls表示分类损失,Lreg表示回归损失。
2 实验设计与结果
2.1 实验设计
2.1.1 实验数据处理
CT 图像的数据量十分庞大,如果一次将整个CT 图像输入网络,对显存的消耗很大。本次实验在训练阶段选择截取较小的包含结节的立方块输入网络,在保证检测精度的同时减少对硬件消耗。具体做法为以数据集标注文件的结节坐标为中心,截取边长为96 的立方块作为正样本。在CT 图像内的非肺结节区域,截取同样尺寸的立方块作为负样本。为了解决数据的正负样本不平衡的问题,对正样本进行随机偏置,在0~360°内随机翻转和以0.75~1.25 之间系数进行缩放作为数据增强手段扩充正样本;同时为了平衡数据集中结节多为小结节的问题,对直径在10 mm~20 mm 和20 mm 以上的结节进行直接复制的方式,扩充中直径结节与大直径结节的数量,平衡结节的直径分布。在模型的测试阶段,CT 图像中依次截取边长为128 的小立方块输入网络。为了不遗漏可能出现在立方块边界的结节,每个立方块间重叠16 个像素。最后将得到的结果按立方块的位置还原到原来CT 图像中。在检测阶段,使用非极大值抑制(Non-Maximum Suppression,NMS)[30]得到最优的预测结果。
2.1.2 实验设置
本实验所用的操作系统为Ubuntu18.04 系统,CPU 为Intel(R)Xeon(R)Gold 6154,内存为64 GB,显卡为3 块显存为12 GB 的NVIDIA TITAN V 显卡。网络模型使用 Python3.6 搭建,框架为Pytorch1.1。优化方法选择随机梯度下降,初始学习率为0.01,权重衰减为1×10-4,批大小为18,模型迭代次数为100 次。为了加速训练,减小在损失全局最优点处的震荡,在训练过程中逐步减小学习率,具体为在训练的前1/3 轮时学习率为0.01,在1/3 到2/3 轮时,设置学习率为0.001,之后学习率为0.000 1。
由于LUNA16 数据集分为10 个子集,故训练与测试使用10 折交叉验证法。选择子集9 为测试集,子集0 到8 作为训练集,然后是子集8 为测试集,其余子集为训练集,依次类推。将所有子集的测试结果汇总,得到最终结果。
2.1.3 评价标准
实验选择无限制受试者操作特征(Free-Response Receiver Operating Characteristic,FROC)[25]验证模型性能。该曲线的横坐标为结节检测的假阳率,以0.125、0.25、0.5、1、2、4、8 这七个假阳率点为代表。纵坐标为结节检测的敏感度,即预测结节的正样本数占数据集总正样本数的比重。在FROC 曲线中,七个假阳率点上敏感度的平均值称为竞争性能指标(Challenge Performance Metric,CPM)[25]。计算公式为:
式中:s表示对应下标位置上的敏感度。
2.2 结果
2.2.1 方法对比
为了验证本文提出的模型检测肺结节的有效性,现将本文提出的方法与近年来具有代表性的肺结节检测方法进行比较。表1 记录了不同的检测模型在LUNA16 数据集上的检测结果,具体为在0.125、0.25、0.5、1、2、4、8 这7 个假阳性上的敏感度,以及它们的平均值(CPM)。
表1 与其他肺结节检测模型对比
从表1 可以看出,本文提出的模型在LUNA16数据集上的检测结果在目前具有代表性的肺结节检测模型中取得较好的结果,在0.125、0.25、0.5、1、2、4、8 这7 个假阳性上的平均敏感度大幅领先对比的其他同类型模型,仅次于张福玲等[28]在2021 年提出的模型。张福玲等提出肺结节检测模型在特征的解码路径上使用特征金字塔的方式将不同深度与尺寸的特征融合在一起,共同作为检测结果由预测层输出。本文提出的模型在特征解码后只将最后的特征图由预测层输出,虽然模型中有将不同深度与尺度的特征图进行融合的机制,从最后的结果看没有张福玲等提出的模型效果好。这表明该模型仍有需要提高之处。
2.2.2 结果可视化
为了更加直观展现肺结节检测,图5 展示密集残差检测模型对肺部CT 图像的检测结果。选择结节中心所在的横断面展示结果,以矩形边框定位结节。图5 中的第一行为检测结果中的真阳性结节,第二行为假阳性结节,假阳性结节具有与真阳性结节相似特征。从图中看出,本文提出的检测方法可以有效检测出多种大小与不同形状的肺结节,对在肺实质中不同位置处的结节均具有较好检测能力,对实性结节与部分实性结节检测结果较好。
图5 肺结节检测结果图
2.2.3 消融实验
为验证提出的跳跃连接信息增强机制、密集连接与注意力机制对肺结节检测性能的提升,在LUNA16 数据集上进行消融实验。在基线方法的基础上依次加入改进措施进行比较,数据处理与实验设置均同2.1.1 与2.1.2,实验结果如表2 所示。消融实验具体为:①model1:基线方法,普通残差3D UNet,网络层数同1.2.2,在编码解码部分间使用直接跳跃连接的模型;②model2:在model1 模型的编码解码部分间使用转置卷积信息增强,融合不同尺度的特征;③model3:在model2 模型的基础上在残差单元后加入注意力机制;④model4:使用本文提出的密集残差肺结节检测模型。
表2 消融实验对比
图6 所示为消融实验结果的FROC 曲线,方块线为model1 模型的FROC 的结果,圆形线为model2模型的FROC 的结果,上三角形线则是model3 模型的FROC 的结果,下三角形线则是model4 模型的FROC 的结果。
图6 消融实验的FROC 曲线
结合表2 与图6 可以看出,model2 相比model1在7 个不同的假阳率点处的敏感度均有提升,尤其在1、2、4、8 这4 个点上的提升较为明显,其平均敏感度提升了1.5%。这表明,在model2 跳跃连接处中新增的用于连接不同深度与尺度特征的信息增强机制对模型的性能起到了不小的提升。model3 相对于model2 在7 个不同假阳性点处的敏感度也有相当的提升,在0.25、0.5、1、2、4、8 这6 个点上提升明显,其平均敏感度提升了2.5%。这表明利用在模型中加入注意力机制可以有效提取肺结节的特征信息,进而提升检测模型的性能。model4 相对于model3 在0.125、0.25、0.5 这3 个点上提升明显,其平均敏感度提升了1.4%。这表明利用在模型中加入密集连接有助于提取肺结节的特征信息,提升检测模型的性能。
3 讨论
本研究提出了一个新的深度学习模型检测CT图像中的肺结节,模型以3D U-Net 网络为主干网络,引入密集连接、残差连接和注意力机制,以端到端的方式检测肺结节。模型在LUNA16 数据集上进行了一系列实验,实验结果表明,所提出的模型可以较好地检测CT 图像中的肺结节。
利用深度学习检测肺结节已取得许多进展[14],但大多数检测网络通过顺序增加卷积核增强网络对特征的提取能力,网络的结构与参数增加同时存在大量特征,导致训练难度增大[22]。本研究在基线模型中引入密集连接与残差连接,对网络中大量的肺结节特征复用,同时可以实现不同深度上的特征之间的信息交流,从而更加全面地提取肺结节特征信息。不同的特征图对于检测肺结节的重要性不同,为了更好地检测肺结节,提升模型检测能力,本研究在模型中引入注意力机制,通过对特征图一系列操作,对不同特征图赋予不同权重,增强模型对与肺结节有关的特征的提取。为了整体网络框架下实现不同深度与尺度特征的交流融合,在网络编码与解码部分间通过转置卷积与拼接操作实现网络不同深度与尺度间特征的交流,增强模型对肺结节的检测能力。为验证模型对肺结节的检测性能,在LUNA16数据集上,以10 倍交叉验证实验。在表2 中通过与基线模型的对比实验结果可以得出,所提出的几处改进均可增强模型对肺结节的检测性能,提升模型对肺结节的敏感度。表1 中通过与其他表现出色的肺结节检测模型的对比,所提出的密集残差肺结节检测模型可以对肺部CT 图像中的肺结节取得较好的结果。本研究提出的密集残差模型对肺结节的敏感度稍逊于张福玲等[28]提出的模型,主要在于其提出的模型中对肺结节特征使用特征金字塔机制,将不同尺度的特征图共同作为输出,这样可以融合不同层次的特征图,而本文提出模型只将最顶层特征作为输出,会忽略一些细节信息,这种设计可在之后的研究中探讨借鉴。
CT 图像的数据标注依赖于有经验的医师,这使得有标注的CT 图像数据收集困难,训练样本较少。迁移学习可以通过在大样本数据集上预训练得到具有一定特征提取能力的网络,在小样本的医学图像数据上小规模训练便可实现较好的检测能力。未来的研究将探讨通过迁移学习实现CT 图像数据的检测。
4 结论
本研究提出了一种密集残差肺结节检测方法,在3D U-Net 基础上引入密集连接、残差连接与注意力机制,较好实现了对CT 图像中的肺结节的检测与定位。大多数的肺结节检测方法,通过顺序堆叠卷积核构建网络,没有充分利用网络中的特征,同时不同层间的特征没有信息交流,对肺结节检测不够精确。对此,本研究在网络中引入密集连接、残差连接与注意力机制改进检测网络,同时在跳跃连接中融合不同尺度的特征。实验结果表明,本研究相对于原始方法的几处改进可以增强网络对肺结节敏感性,平均敏感度提高了5.5%,与其他结节检测算法相比,也取得了较好表现。
在未来工作中,针对有标注肺结节数据较少的问题,将结合迁移学习,将深度学习模型从其他领域图像数据中学习到的检测知识,运用到肺结节检测中,进一步提高结节检测精确度[34]。