基于3D-ResNet的阿尔兹海默症分类算法研究*
2020-06-22廖文浩
郁 松,廖文浩
(中南大学计算机学院,湖南 长沙 410075)
1 引言
阿尔兹海默症AD(Alzheimer’s Disease)是最常见的痴呆症,占痴呆症病例的60%~70%,其发病机制尚不明确,一般都是到了晚期才被发现,而且没有任何治疗可以阻止或逆转其进展,只能通过药物治疗来减缓认知衰退的进展[1 - 4]。快速增长的老龄人口和沉重的照护负担,成为当下社会不得不面对的紧迫问题。核磁共振成像MRI(Magnetic Resonance Image )是一种被广泛接受的非侵入性技术,其通常用于临床AD的辅助诊断[5,6]。然而,传统的AD诊断不仅耗时耗力,而且无法实现精确辨别,特别是对于早期阿尔兹海默症患者。此外,现有的方法或工具能够识别放射科医师无法轻易检测到的AD患者的细微变化,因此,迫切需要可信而有效的方法或工具来筛查正常人和AD患者。
随着深度学习技术的快速发展,特别是卷积神经网络CNN(Convolutional Neural Network)技术在图像处理和自然语言处理等领域取得的显著成效,越来越多的领域都开始尝试使用深度学习技术来解决问题。在深度学习领域,He等人[7]提出的ResNet模型很好地解决了神经网络的梯度消失或爆炸等问题,并且加快了网络模型的收敛。Selvaraju等人[8]尝试运用类激活映射CAM(Class Activation Mapping)技术来反向解释卷积神经网络,并取得了不错的成效。近年来,很多学者采用深度学习技术来进行AD分类。Fulton等人[9]通过50层的残差网络对核磁共振成像数据进行分类来预测阿尔兹海默症,在OASIS(Open Access Series of Imaging Studies)数据集上实现了98.99%的准确性。Khagi等人[10]构造了一个层数较浅的卷积神经网络模型对AD进行分类,并在OASIS数据集上实现了98.51%的准确性。Maqsood等人[11]通过预训练来微调卷积神经网络中的AlexNet模型,并使用微调后的模型对阿尔兹海默症的核磁共振影像进行分类,该模型采用未分割的图像进行训练和测试,并在OASIS数据集上进行评估,取得了92.85%的准确性。Luo等人[12]采用卷积神经网络模型对AD的核磁共振影像进行识别,实验表明该模型具有较高的AD识别精度,敏感性为100%,特异性为93%。Cheng等人[13]提出构建级联的3D卷积神经网络(3D-CNN),以分层学习多层图像特征,并将这些特征集合在一起来对AD进行分类,实验结果表明,所提出的方法对AD和正常对照组NC(Normal Control)的分类精度达到92.2%。Ju等人[14]采用深度学习对阿尔兹海默症进行早期分类,并建立有针对性的自动编码器网络,以区分正常衰老和轻度认知障碍(AD的早期阶段),所提出的深度学习方法可将预测精度提高约31.21%,并且在最佳情况下可使标准偏差降低51.23%。Payan等人[15]提出了一种深度学习方法,包括稀疏自动编码器和3D卷积神经网络,它可以基于MRI扫描预测患者的疾病状态,在预测AD大脑和健康大脑的数据时,达到了95%的准确性。Liu等人[16]提出了一种用于多类AD诊断的多模态神经影像特征提取模型,它开发了一个深度学习框架,使用零屏蔽策略来保留成像数据中编码的所有可能信息,实现了87%的高精度。Sarraf 等人[17]设计了深度学习模型用于区分阿尔兹海默症的MRI和fMRI与特定年龄组的正常人数据,其几乎完美地区分了阿尔茨海默证患者和正常人的大脑。 Hosseini等人[18]提出用3D卷积神经网络预测AD,该网络建立在3D卷积自动编码器上,该自动编码器经过预先训练,可捕获结构性脑MRI扫描中的解剖形状变化,在没有进行颅骨剥离预处理的MRI数据集上的实验表明,它的准确性优于几种传统的分类器。Vu等人[19]组合稀疏自动编码器和卷积神经网络用于多模态学习。
基于卷积神经网络的阿尔兹海默症的分类已然成为一大趋势。然而,现有的基于卷积神经网络的阿尔兹海默症分类方法大多只是简单地将在自然图像上分类效果较好的卷积神经网络模型迁移到阿尔兹海默症的分类,其模型并没有针对阿尔兹海默症的核磁共振影像的特点进行改进,且使用的数据集规模较少,因此基于卷积神经网络的阿尔兹海默症分类方法还有很大的提升空间。
本文主要研究的是基于3D-ResNet的阿尔兹海默症分类。通过分析现有阿尔兹海默症分类方法的不足,本文在2D-ResNet模型的基础上进行改进,并设计了一种3D-ResNet模型用于阿尔兹海默症的分类。首先,对阿尔兹海默症的核磁共振影像进行预处理,以生成核磁共振影像的灰质GM(Gray Matter)和白质WM(White Matter)的分割图;接着,用预处理后的核磁共振影像进行模型的端到端的训练和验证;最后,用训练好的模型进行阿尔兹海默症的分类。本文主要的工作如下:
(1)设计了一个3D-ResNet模型用于阿尔茨海默症分类。
(2)采用类激活映射(CAM)技术来可视化与AD相关的脑部区域。
2 3D-ResNet模型
2.1 问题描述与分析
现有阿尔兹海默症的分类方法大致可以划分为人工分类、半自动化分类和自动化分类3种。人工分类阿尔兹海默症的方式带有主观因素。当医生长时间观看影像时会感觉到眼睛疲惫,容易导致阿尔兹海默症患者的漏诊或误诊的现象发生。此外,医生通过肉眼分析阿尔兹海默症的核磁共振影像,对早期阿尔兹海默症的细微病理变化不敏感,很难及时检查出早期的阿尔兹海默症患者。基于机器学习的阿尔兹海默症分类的特征提取器是依据人工经验设计的,容易导致所提取的特征与输出任务弱相关,具有较强的主观因素,不能充分反映图像的本质特征,进而使最终分类的结果不够准确,产生漏诊或误诊的现象。相比基于人工的阿尔兹海默症分类方法和基于机器学习的阿尔兹海默症分类方法而言,基于卷积神经网络的阿尔兹海默症分类方法无需依据人工经验来提取特征,模型通过训练可以自动从阿尔兹海默症的核磁共振影像中提取特征,而且所提取的特征与分类任务高度相关,这也是本文选择采用卷积神经网络技术对阿尔兹海默症的核磁共振影像进行分类的主要原因之一。
由于卷积神经网络中的ResNet模型的收敛速度快和收敛效果较理想,因此本文在阿尔兹海默症的分类问题上引入卷积神经网络技术中经典的ResNet分类模型。在基于残差模块的ResNet模型的基础上进行修改是为了避免模型训练过程中出现梯度消失问题,如式(1)和式(2)所示:
∂Loss/∂X1=∂FL(XL,WL,bL)/
∂XL*…*∂F2(X2,W2,b2)/∂X1
(1)
∂XL/∂Xl=(∂Xl+∂FL(XL,WL,bL))/∂Xl=
1+∂FL(XL,WL,bL)/∂Xl,l∈(1,L)
(2)
其中,∂代表求导符号,XL代表第L层的特征图,WL代表第L层的权重参数,bL代表第L层的偏置参数,FL(XL,WL,bL)表示XL,WL,bL组成的函数。
当网络模型非常深的时候,越往前传的梯度值就越小,趋近于0,不采用残差模块的模型容易出现梯度消失的问题,而采用残差模块的模型永远不会产生梯度消失的问题,因为回传的梯度值在1左右。此外,深度神经网络通常比浅层神经网络的学习能力更强。只是简单地堆叠卷积层以构建深度神经网络不是一种行之有效的方法,因为随着神经网络中的层数增加,梯度的传播将受到显著阻碍。然而,基于残差模块的ResNet模型的深度增加不受约束,这样可以大大提升模型的拟合能力,而且采用残差模块的ResNet模型的损失函数表面非常光滑,如图1所示。其中,没有采用残差模块的模型损失函数(如图1a所示)表面非常崎岖,这样很容易使模型收敛于局部最优值,从而导致模型的性能较差。采用残差模块的模型损失函数(如图1b所示)表面非常光滑,这样既能加快模型的收敛速度,还能使模型收敛得较好。
Figure 1 Loss function surfaces of the models图1 模型的损失函数表面
2.2 3D残差模块
ResNet模型最先应用于2D图像分类,然而,对于像MRI这样的3D医学图像数据,空间信息在阿尔兹海默症的分类中起着重要作用。3D深度学习模型可以从3D MRI图像中更好地识别解剖位置和病理特征。为了充分利用3D MRI中的上下文信息,本文扩展了原始ResNet支持体积数据的能力,使用体积级3D-ResNet代替切片级2D-ResNet。具体来说,本文改进了原始ResNet模型的残差模块,将所有2D版本的卷积层、批量标准化层和池化层调整为3D版本的,如图2所示,图2a为原始的2D残差模块,图2b为本文提出的3D残差模块。3D残差模块的流程如算法1所示,其中,σ表示采用激活函数进行非线性操作,BN表示归一化操作。
Figure 2 Residual module图2 残差模块
算法13D残差模块
输入:特征图x。
输出:特征图xr。
步骤13D的1*1*1卷积操作:通道的降维和信息融合。
x11=σ(BN(conv1*1*1(x)))
步骤23D的3*3*3卷积操作:特征提取。
x3=σ(BN(conv3*3*3(x11)))
步骤33D的1*1*1卷积操作:通道的升维和信息融合。
x12=σ(BN(conv1*1*1(x3)))
步骤4残差特征图和输入特征图的信息融合,特征图的逐像素相加。
xr=x+x12
2.3 网络结构
本文以3D残差模块为基础,构建了一种3D-ResNet算法用于AD分类,如图3所示,其主要由预处理和3D-ResNet-101 2部分组成,其中101表示模型有101层。预处理部分负责将原始的166*256*256 MRI图像分割成121*145*124的MRI灰质图像和121*145*121的MRI白质图像;3D-ResNet-101部分主要由一个7*7*7卷积层、一个3*3*3的池化层、一系列对应不同特征通道数的3D残差模块和全连接层组成。3D-ResNet-101部分将预处理后的灰质图像和白质图像叠加在一起进行端到端的训练,其中*3,*4,*6,*3分别代表3D残差模块的个数,并以二分类的形式给出模型的预测结果,NC和AD分别表示正常人和阿尔兹海默症患者的分类结果。3D-ResNet算法如算法2所示。
Figure 3 Classification model structure for Alzheimer’s disease based on 3D-ResNet-101图3 基于3D-ResNet-101的阿尔兹海默症分类模型的基本结构
算法23D-ResNet算法
输入:特征图x。
输出:分类结果NC或AD。
步骤13D的7*7*7的卷积操作:特征提取。
x1=σ(BN(conv7*7*7(x)))
步骤23D的3*3*3的池化操作:减小特征图分辨率。
x2=pooling3*3*3(x1)
步骤3输出残差模块操作得到的特征图。
步骤3.1输出经过3个残差模块操作的特征图。
x3=R(x2)*3
步骤3.2输出经过4个残差模块操作的特征图。
x4=R(x3)*4
步骤3.3输出经过23个残差模块操作的特征图。
x5=R(x4)*23
步骤3.4输出经过3个残差模块操作的特征图。
x6=R(x5)*3
步骤4用3D的全连接操作进行分类。
c_vector=(PNC,PAD)=FC(x6)
步骤5取c_vector向量最大概率值所在下标对应的类别。
class=max(c_vector)
3 实验
3.1 数据集
本文使用的数据来自阿尔茨海症神经影像学倡议(ADNI)数据集,ADNI的主要目标是测试是否可以将连续核磁共振影像(MRI)、PET、其他生物标志物以及临床和神经心理学评估结合起来,以测量轻度认知障碍和早期AD的进展。ADNI数据集主要包含3T类型的MRI数据和1.5T类型的MRI数据。
本文实验选用1.5T类型的MRI数据,其是由639人的扫描数据组成,其中一个人有可能有多次扫描数据,扫描的数据类型分为正常人(NC)数据、轻度认知障碍(MCI)数据和阿尔茨海默证患者(AD)3种,如表1所示,ADNI数据集中的1.5T类型的MRI数据共有3 299个,其中包含1 015个CN数据,1 709个MCI数据,575个AD数据。
Table 1 1.5T type MRI data in ADNI dataset表1 ADNI数据集中1.5T类型的MRI数据
本文从ADNI数据集中的1.5T类型的AD和NC的MRI数据中随机选取了1 163个数据,其中包含575个166*256*256的AD数据和588个166*256*256的NC数据,如表2所示。NC和AD数据的个数比趋近于1∶1,这是因为拟合的模型往往会倾向于数据量多的那一类数据,因此将各类数据设置成1∶1可以避免这种数据倾斜问题的发生。
Table 2 Data used in the experiment表2 实验中使用的数据
本文将整个数据随机分成3部分:训练集(64%)、验证集(16%)和测试集(20%),如表3所示。其中训练集的数据量∶验证的数据量∶测试集的数据集量约为6∶2∶2,按照这个比例来划分数据集是为了让模型在训练集上得到充分的训练,在验证集上进行实时的监控,在测试集上进行实际的预测。
Table 3 Partitioning of the experimental data表3 实验数据的划分
3.2 数据预处理
本文采用的MRI数据均使用基于Matlab的SPM12软件包进行预处理。本文引入了基于SPM12的标准化CAT12(用于SPM的计算解剖工具箱)包用于数据预处理,并选取CAT12包中的121*145*121的templates_1.50 mm图像作为数据预处理的模板,选取CAT12包中的121*145*121的neuromorphometrics图像作为数据预处理中区域分割的标签模板。其中标签模板将脑部一共划分为142个脑区,选取DPABI 2.2包中的ch2图像作为原图模板,并将其尺寸从181*217*181重置为同templates_1.50 mm的尺寸121*145*121相同大小,进而用于显示与CAM技术生成的mask图像的叠加。预处理的步骤包括空间配准、组织分类(灰质、白质和脑脊液)和强度非均匀性的偏差校正。本文将预处理后AD病患和正常人生成的灰质(GM)和白质(WM)数据相叠加用于3D-ResNet-101模型的训练、验证和测试。如图4所示,输入的原始MRI图像的尺寸为166*256*256,输出的预处理后的MRI图像尺寸分别为121*145*121的灰质图像和121*145*121的白质图像。
Figure 4 Preprocessing of magnetic resonance imaging图4 核磁共振影像的预处理
3.3 模型的训练
本文所设计的3D-ResNet算法采用的初始化方法为kaiming initialization[20],且算法中的参数都是从头到尾训练,所有的实验都是在一块Tesla K40 GPU 上完成的,GPU的显存约为11 GB。模型训练的超参数如表4所示,其中模型总的迭代次数设置为50,批量训练的大小为6,学习率从0.001开始,每迭代20次后,学习率除以10。
Table 4 Hyper parameters for model training表4 模型训练的超参数
3.4 结果与分析
为了评估所设计的3D-ResNet-101模型对AD分类的性能,本文将其与其他模型进行了准确性(accuracy)的对比,如表5所示。accuracy的定义如式(3)所示,其反映了模型准确分类AD病例和正常人的数量,accuracy越大,模型正确分类AD病例和正常人的数量就越多。
(3)
其中,TP、TN、FP和FN分别代表真阳性、真阴性、假阳性和假阴性。
Table 5 Comparison of accuracy among different models表5 不同模型的准确性对比
为了进一步评估所设计的3D-ResNet-101模型对AD分类的性能,本文评估了基于3D-ResNet-101模型的敏感性(sensitivity)和特异性(specificity)指标,如表6所示。sensitivity的定义如式(4)所示,其反映了模型准确分类的AD病例数量,sensitivity越大,漏诊的AD病例就越少。specificity的定义如式(5)所示,其反映了模型准确分类的正常人的数量,specificity越大,正常人被误诊为AD的病例就越少。
(4) Table 6 Sensitivity and specificity results of 3D-ResNet-101 model on verification set and testing set表6 3D-ResNet-101模型在验证集和测试集上的敏感性和特异性
(5)
本文还探讨了模型训练参数不同对分类结果的影响。当初始学习率参数设置为较大的数0.1时,模型训练过程中会出现收敛抖动的现象,从而导致模型收敛时间延长或无法收敛。由于实验条件的限制,批量处理的大小最大可设置为6。当批量处理大小小于6时,模型最终同样会收敛,但所需要的收敛时间会延长。
类激活映射(CAM)技术是将最后一个卷积层的输出特征图和该特征图对应的输入图像的仿射矩阵相结合,进而生成与输入图像尺寸相同的掩码图,如图5所示,其中响应越高的图像区域,越能代表3D-ResNet算法的分类依据。
Figure 5 Structure of class activation mapping图5 类激活映射技术的基本结构
此外,本文将原始的MRI图像进行预处理,并生成相应的灰质图像和白质图像,然后采用CAM技术生成掩码图像,并将掩码图像中的高响应区域可视化到标准的脑模板上,如图6所示,其通过Matlab 2016b的DPABI Viewer来展示叠加效果,并从axial、coronal和sagittal 3个不同的方向进行具体显示。
3.5 讨论
为了探索3D-ResNet算法的深度对AD分类的影响,本文对比了3D-ResNet-50模型和3D-ResNet-101模型在训练过程中accuracy和loss的变化情况,结果如图7所示。其中3D-ResNet-101模型比3D-ResNet-50模型拟合目标数据的效果更好,由此可以推出适当增加3D-ResNet算法的层数能提升模型的拟合能力。
Figure 6 Visualization of Alzheimer’s disease classification based on 3D-ResNet图6 基于3D-ResNet的阿尔兹海默症分类的可视化
Figure 7 Model depth impact on Alzheimer’s disease classification图7 模型的深度对阿尔兹海默症分类的影响
Figure 8 Acquisition of 2D MRI slices图8 2D MRI切片的获取
此外,本文通过取MRI图像的中间切片将3D图像转成2D图像的形式,以便于采用2D神经网络模型做AD的分类。以3D单通道灰质图像转2D单通道灰质图像为例,图8显示了本文将3D MRI图像转为2D图像的原理,其中(37,44,37)表示分别从sagittal、coronal和axial 3个方向取121*145*121的第37个2D图像和第44个2D图像、第37个2D图像,同理,类似的还有(61,73,61)和(97,116,97)。由于(61,73,61)位置的切片能将脑部轮廓全部展现出来,因此本文选取的2D图像就是从该位置获得的,并将获取的2D图像重置为28*28的单通道图像。本文将axial方向的GM和WM对应的2D图像叠加在一起,然后输入到2D神经网络模型进行训练进和预测。本文探讨了使用体积级3D-ResNet代替切片级2D-ResNet的区别,如表7所示,在模型参数和训练时间上,体积级3D-ResNet的模型都比切片级2D-ResNet的模型增加了很多,但在准确性上,体积级3D-ResNet的模型却比切片级2D-ResNet的模型提升将近一倍。
Table 7 Difference between volume-level 3D-ResNet and slice-level 2D-ResNet表7 使用体积级3D-ResNet代替切片级2D-ResNet的区别
4 结束语
本文提出了一种3D-ResNet模型用于准确而快速地对AD进行分类,并用accuracy、sensitivity和specificity性能指标对3D-ResNet模型进行了评价。针对AD的发病机制尚不明确的问题,本文通过类激活映射(CAM)技术来可视化与AD相关的脑部区域。此外,本文的模型还可用于诊断其他类似的神经疾病,如帕金森和小儿麻痹症等。未来我们希望将本文所提出的基于3D-ResNet的AD分类算法应用于实际生活中,并从精度和速度上进一步优化。