基于残差自注意力机制的阿尔茨海默症分类*
2022-02-03卢星进王书卜肖磊高礼彬李瑞胡众义
卢星进,王书卜,肖磊△,高礼彬,李瑞,胡众义△
(1.温州大学 计算机与人工智能学院,温州 325035;2.温州市智能影像处理与分析重点实验室,温州 325035;3.上海东方医院,上海 200120)
引言
阿尔茨海默症(Alzheimer′s disease,AD)是一种发病进程缓慢且不可逆的神经退行性疾病,可导致患者大脑神经细胞死亡和组织萎缩[1-2]。据报道,目前AD约占所有痴呆病例的60%~70%,有超过5 500万人患有AD,预测到2030年将达到7 800万,治疗总成本将达到2万亿美元。尽管目前对AD患病的原因尚不明确且缺乏特效药物来阻断发病、治愈AD,但是若对AD进行早期诊断和治疗即可有效延缓其恶化。在临床诊断中,磁共振成像技术、弥散张量成像技术和脑电图等方式能够展示患者大脑的结构变化,其中,结构性磁共振成像技术(structural MRI,sMRI)产生的图像能够更好地显示器官的萎缩、变形等。因此,本研究基于sMRI 图像设计残差自注意力框架以实现对AD的辅助诊断。
1 相关工作
在计算机辅助诊断领域,近年已有多项研究探讨阿尔茨海默症,以实现对该疾病的计算机辅助诊断。黎建忠等[5]提出了一种基于三类特征的支持向量机模型,三维重构 sMRI 的灰质体积、皮层表面积及厚度作为特征输入进行AD和正常对照组(normal controls,NC)的分类识别。曾安等[6]将大脑脑区预处理的感兴趣区域作为特征,输入3D卷积神经网络用于分类任务。Agrawal等[7]使用深度卷积神经网络用于AD的分类。曾安等[8]提出了基于卷积循环神经网络的AD早期诊断框架,该方法结合2D卷积和循环神经网络,将2D切片序列输入框架,对AD和NC的分类准确率达到93.3%。Dai等[9]将功能连接网络(FCN)和结构连接网络(SCN)的拓扑结构相关信息应用于AD分类,具有较好的效果。曾安等[10]利用脑区模板标签(AAL)划分受试者的大脑区域,提出直映式注意力机制,提高了模型的稳定性及准确率。Liu等[11]提出构建级联卷积神经网络,学习MRI和PET脑图像特征用于AD分类。Shi等[12]建立光滑非线性特征空间变换来映射数据,并使用深度网络特征融合策略对脑图像的横截面和纵截面进行融合,评估特征转换和特征融合策略的有效性。曾安等[13]设计基于卷积神经网络和集成学习的切片集成分类模型用于AD早期分类。潘丹等[14]提出基于3D卷积神经网络和遗传算法相结合的AD早期辅助诊断模型,首先利用3D卷积对数据训练出候选基分类器,然后利用遗传算法选出最优基分类器组合,最终集成分类。Jain等[15]采用基于迁移学习的模型,训练VGG作为特征提取器,用于AD和NC的分类任务。Qiu等[16]提出一个可解释的深度学习策略,框架中使用全卷积网络,从MRI、年龄和精神状态检查分数等多模态数据中描绘AD特征,得到AD的准确诊断。Mehmood等[17]开发了一个Siamese卷积网络(SCNN)模型,并在AD和NC的分类中得到了99.05%的准确率。Xing等[18]创新性地改造了近似秩池化操作,将三维MRI图像转为二维图像,作为二维卷积输入,相比于三维模型,其AD分类准确率提高了9.5%。Esteva等[19]使用强化学习方法,达到了对医学成像数据的精准分类。基于关键切片投票的方法[20]从关键切片区间的角度,利用切片投票的分类结果,诊断患病状态,通过多张切片集成分类。虽然该方法的准确率和稳定性都有一定程度的提升,但是仍有未进行切片权重分配,进而模糊核心切片区间比重的问题。
针对上述问题,本研究受自注意力机制[21]和集成学习的启发,在关键切片投票方法[20]的基础上,使用残差网络提取图像特征信息,再通过切片自注意力机制替代切片投票机制学习特征信息,并进行权重分配,使得每一张切片拥有动态的权重,达到增强核心分类区间比重的作用,最终在不同权重的切片下集成分类结果,可智能诊断AD和NC,达到精准分类。通过将自适应梯度裁剪应用于图像分类模型框架,在训练大批量和大规模的数据时,更容易达到收敛,提升了模型诊断的可靠性。
2 数据集与预处理
2.1 数据获取
本研究所有数据均来自阿尔茨海默病神经影像学计划(Alzheimer′s disease neuroimaging initiative, ADNI)数据库(www.loni.ucla.edu/adni),从中下载431名受试者的MRI图像,其中AD 200例,NC 231例,年龄分布在70~84岁,性别比例保持均等。详细信息见表1。
表1 数据集基本情况
2.2 数据预处理
利用FreeSurfer软件进行图像预处理,主要包括头部矫正、磁场不均匀信号偏差矫正、MNI305模板配准、白质区域信号的标准化和非脑组织(比如颅骨、胫部等)的去除,处理过程中的参数使用默认参数,经过FreeSurfer软件预处理后的图像大小均为256×256×256,再进行粗裁剪去除边缘无用的非脑区域,重采样得到图像大小为180×180×150。预处理流程见图1。
图1 数据集预处理Fig.1 Dataset preprocessing
在之前切片工作[20-22]中,使用预处理完的数据分别测试了从独立测试个体的沿轴向、冠状和矢状解剖平面提取的切片,表明冠状面具有最高的准确率和敏感度,因此本研究实验最终选取40个冠状面切片用于实验,其索引区间为[70,109]。最终在431个样本中生成8 000张AD患者的MRI切片,9 240张NC的MRI切片。
3 方法
本研究基于残差自注意力机制,使用残差网络[23]提取每个切片图像特征信息,利用自注意力机制学习特征信息,并进行权重分配,增强核心分类区间的比重后,使用不同权重的切片特征信息集成分类。同时,在模型中增加自适应梯度裁剪(adaptive gradient clipping, AGC)[24],以便于达到收敛。
医生阅片时,会依据大脑切片的图像判断患者是否患病,但是在不同位置的切片会影响医生决策。因此,引入自注意力机制,对不同切片位置设置权重,并进行动态分配,最终能够增强核心切片区间的权重,进而实现集成分类。
基于残差自注意力机制的框架中,残差网络的作用是学习每张冠状切片图像,并提取对应切片位置的特征信息;自注意力机制则是学习切片间的特征信息比重,对权重动态分配,增强核心分类区间的比重。最终将所有切片权重信息集成,得到AD分类结果。该框架很好地提高了视觉信息处理的准确性。具体模型框架见图2。
图2 残差自注意力框架Fig.2 Residual self-attention framework
本研究使用的框架结构分左、中、右三部分。其中,左边表示sMRI数据集;中间绿色部分表示40个sMRI切片,蓝色部分表示残差模块;右边橙色部分表示经过残差网络后提取的特征信息,绿色部分表示值向量序列,蓝色部分表示键向量序列,灰色部分表示查询向量序列,红色部分表示权重向量。
3.1 残差网络
残差网络(residual network,ResNet)采用恒等映射结构,减轻了普通卷积神经网络的训练难度。ResNet由卷积层、池化层等构成的多个残差块堆叠组成。特征矩阵以两个分支进入残差块,直线分支进入多个卷积层输出特征矩阵。跳跃分支在输入输出一致时使用恒等映射,不一致时使用线性投影保证维度一致。最终对直线分支和跳跃分支输出特征的矩阵进行求和并进入下一层。残差学习算法见式(1)。
(1)
本研究主要基于ResNet-34模块实验,ResNet-34模块有34个堆积层结构,先通过步长为2的池化进行下采样,再通过中间16层残差块,最后网络以全局池化层和softmax函数的全连接层输出分类结果。不同残差块层的输入和输出有差别,结构基本相似。使用ResNet-34可以减缓过拟合或者梯度消失/爆炸问题,拥有更好的性能。
3.2 自注意力机制
自注意力机制采用表征加权的方式有助于获得特征的内部相关信息。自注意力机制可以建立序列依赖关系,通过依赖关系引入权重,在通道和空间层面计算每个单元通道与通道、像素点与像素点之间的值,以加强两者之间的联系。其中权重计算引入的Q,K,V分别为查询向量序列、键向量序列和值向量序列,分别学习了参数矩阵,得到自注意力权重大小。自注意力权重算法见式(2)。
(2)
其中,j表示第j层冠状切片,L是线性映射用来匹配维度。Kj,Qj和Vj分别表示注意力机制的查询,键和值向量矩阵,ω表示激活函数。
本研究在引入自注意力机制时会对40个切片产生40个不同权重信息,这些权重信息构成的注意力矩阵见图3。其中每个格子的颜色代表权重值,颜色越深权重值越大,具体数值用图中热力棒刻度表示。本研究在预处理中选取切片索引为[70, 109],将第70个索引作为第1个格子,按照顺序依次分配给40个格子。其中横坐标表示格子坐标的个位数,纵坐标表示格子坐标的十位数。由图可知,越靠近切片索引中间位置特征信息越稳定,权重值越大。
图3 注意力矩阵
3.3 自适应梯度裁剪
本研究中,实验设置的自适应梯度裁剪是解决梯度爆炸的关键技术,可替代神经网络中归一化层用于残差网络。传统归一化的方法不适用于大批量训练且难以选择最佳学习率。自适应梯度裁剪的方法,基于梯度范数与参数范数的单位比例来进行剪切梯度。在数据规模比较大时,使用自适应梯度裁剪能够使得寻找最优解的过程变得平缓,模型更容易达到正确的收敛水平。本研究使用AGC替代批归一化层用于网络中,在寻找平缓最优解的基础上,能训练大批量和大规模的数据快速收敛并取得稳定的准确率。
4 实验
设置训练集和测试集比例为8∶2,选取对于分类效果最佳的区间[70, 109]。为确保对比公平性,采用相同的参数配置。最终实验使用ResNet-34模块作为基础网络,采用RMSprop优化器,α:0.99,学习率:0.0005,偏置:1e-8,实验设置100个迭代次数,在最终无变化时,停止训练。经对比发现,VGG和AlexNet网络结构较为臃肿不适合修改。在本节中,基准实验使用ResNet-34模块作为基础网络,再结合关键切片投票方法。在基准实验上改进关键切片投票方法,使用自注意力机制动态分配权重,最终实验基于残差自注意力机制框架。该框架分别与不同网络进行对比,分类性能参数包括准确率(ACC),召回率(REC),精确率(PRE)和均衡平均数(F1-Score)。结果见表2、表3。
表2 不同网络结果对比
表3 不同模块结果对比
4.1 不同网络效果
本研究提出的模型四个方面均优于PCA+SVM和传统神经网络的框架,其中PCA+SVM代表基于主成分分析和支持向量机的诊断模型[13];2DCNN表示使用单切片训练的二维卷积神经网络[13];ResNet是基于ResNet-34的实验设置;3D-PCANet为文献[14]中提出的一种非监督学习方法;VGG by slice代表文献[20]基于关键切片投票的方法。由表2可知,基于残差自注意力机制框架可以提高传统神经网络的性能,同时也能够有效捕捉关键特征信息。
4.2 不同模块效果
表3为本研究模型与不同模块在同一主干网络下相对比,实验结果显示,本研究在准确率、召回率、精准度和均衡平均数均优于基准实验,准确率比基准实验高2.4%,其中ResNet是基于ResNet-34的实验设置;Baseline是基准实验。这表明基于残差自注意力机制的诊断方式能够达到更好的分类效果。该方法通过捕捉关键特征,学习特征间的显著性,动态分配权重信息,在增强核心分类区间比重的方式上,达到不同权重集成分类。最终框架对AD分类具有较高的敏感度和鲁棒性,该模型能够更好地提升AD分类的性能,为辅助医生诊断AD提供客观依据。
5 结论
本研究提出了一种基于残差注意力机制的分类模型。该模型使用残差网络提取每张切片图像的特征信息,然后通过自注意力机制学习提取的特征信息,对切片权重动态分配,增强核心分类区间的比重,最终集成所有的权重结果,对AD和NC进行分类。实验结果在对比不同主干网络和不同模块方面都取得了良好效果。基于残差自注意力机制的召回率、精确率和均衡平均数都高于基准方法,对模型的改进有很好的效果,同时对于AD的辅助诊断也具有重要意义。在后续的研究中,将会尝试使用更多维度的信息(比如表征信息等)来提取特征,结合自注意力机制进一步提高模型性能。