基于对抗域适应的心电信号深度学习分类算法
2023-03-21蒋思清陈潇俊高豪俊何佳晋
蒋思清,陈潇俊,高豪俊,何佳晋,吴 健,
(1.浙江大学公共卫生学院,浙江 杭州 310058;2.浙江大学睿医人工智能研究中心,浙江 杭州 310000)
0 引言
在临床上,心血管疾病常伴有心律失常症状,严重的心律失常可能会导致猝死或心力衰竭的发生,因此,及时、准确地检测心律失常是十分必要的[1]。心电图提供了丰富的心脏健康和病理信息,是诊断心脏疾病的重要方法[2]。心律失常情况通常要由医生给出诊断,但在临床实践中,受到医生经验的差异性和严重的心电噪声影响,往往会出现误诊和漏诊[3]的情况,因此需要找到一种识别带有噪声的心电信号,即可实现对心律失常类别进行准确判断的方法,进而帮助医生更精准地发现心律失常事件。此外,在人工智能技术迅速发展的背景下,目前计算机辅助心电图分析的方法[4]越来越受到大家的关注,成为了心电图领域的研究热点之一,其中高精度的心电图自动诊断在心血管疾病的预防和辅助诊疗中起到了关键作用,具有良好的实践前景和重要的医学价值。
目前,临床心电图在计算机领域的分类方法包括2 个大类,即基于信号处理技术的方法和基于深度学习的方法[5]。第一类信号处理方法包括小波特征、高阶统计量、功率谱特征等通过提取特征向量的方式针对波形的振幅和频率进行分析[6-11]。例如Martis等[12]对正常、心房颤动和心房扑动心电图信号进行了独立成分分析(Independent Component Correlation Algorithm,ICA),其分类效果显著高于K 邻近分类 器;Acharya 等[13]提 出 了 一 种 计 算 机 辅 助 诊 断(Computer Aided Diagnosis,CAD)方法,他们从心电信号中提取了熵特征并通过决策树对14 个重要特征进行分类;Ye等[14]采用形态学和动态特征方法,检测准确率达到86.4%。虽然上述传统的分类方法取得了良好的结果,但从心电信号中提取关键特征并构建心电特异性特征向量的过程复杂且耗时,其次,分类效果与特征的选择高度相关,特征的选择容易受到主观因素的影响,加上过拟合问题[15]的出现,在实际应用中具有一定的局限性,难以达到预期的诊断精度。相比之下,第二类深度学习方法比传统方法更具优势。它们可以优化心电信号特征提取的过程,取得更好的分类性能和泛化能力。Yildirim 等[16]采用长短期记忆神经网络(Long Short-Term Memory,LSTM)对心电进行识别和分类;Chu等[17]首先提出了一种针对多导联心电信号的二维卷积神经网络(Convolutional Neural Networks,CNN)来提取交叉导联心电图特征,将CNN 和LSTM 提取的特征与传统特征相结合,采用二值粒子群优化算法(Discrete Binary Particle Swarm Optimization Algorithm,BPSO)对特征进行区分和选择,最后选择加权支持向量机作为分类器时效果表现良好;Sellami 等[18]采用Resnet 对心电信号进行自动分类,并提出了批量加权损失来处理数据不平衡的问题,每个输入数据包含2 个心拍以便更好地学习特征,总体分类准确率达到89%。
随着可穿戴技术的发展,心电数据的采集更加方便、高效[19],但由于数据标注的成本较高,许多普通医院无法支持相应设施,导致许多标注不完整,无法获得足够的优质训练数据。心电图数据的分布由于个体差异以及数据采集来源和方法不同也存在差异,而迁移学习中的领域自适应方法在处理标记较少的数据及其分布差异方面具有优势,因此可以将迁移学习[20]引入心电信号分类任务中。本文重点研究基于迁移学习的自动分类技术,采用大量标记样本作为源域数据,少量未标记样本作为目标域数据,利用域适应提高域间迁移能力,多尺度特征提取器则有助于更好地学习信号的复杂特征,提高模型分类性能。基于对抗域适应来解决标记不完整和个体变异问题,并对8 类常见的心律失常进行智能分类的方法应用于临床心电领域的相关研究在国内外鲜有报导。
本文的3个主要工作如下:
1)提出一种基于对抗域自适应学习的心电信号分类方法,解决标记训练样本不足的问题,改善个体差异导致的数据分布差异的现象。
2)对该方法的A、B、C这3大模块各组成部分分别进行优化,进而实现对临床上8种常见的心律失常类别的精准分类,为心电图临床决策和辅助诊断提供依据。
3)提出4 个关键的时间特征,并将其与深度学习特征进行串联,增加特征的丰富性。
1 数据分析与预处理
心电信号分类任务流程如图1 所示。首先对心电记录中的信号进行预处理,生成输入数据,然后将其传递到对抗域自适应模型中进行训练,将自动提取的模型特征与人工提取的时间特征进行融合并输入到分类器中,得到最终的分类结果。
图1 心电信号分类任务流程图
1.1 心电图数据介绍
在本文中使用的数据集和注释来自麻省理工学院MIT-BIH 心律失常数据库。该数据库共48 条0.5 h(导联数为2)的心电信号记录,每组信号以360 Hz来进行采样,单位电压5 μV[21]。其中未明确分类搏动的研究意义相对较少,不能作为判断分类结果的依据,所以在心电研究中将其剔除。
根据MIT-BIH 心律失常数据库提供的注释,较主要筛选出8 种临床心率失常类别,分别是正常窦性心律NOR(Normal Sinus Rhythm)、左束支传导阻滞LBBB(Left Bundle Branch Block)、右束支传导阻滞RBBB(Right Bundle Branch Block)、房性早搏APB(Atrial Premature Beat)、室性早搏PVC(Premature Ventricular Contraction)、步 速 跳 动PAB(Paced Beat)、室性逃逸跳动VEB(Ventricular Escape Beat)、心室颤振VFW(Ventricular Flutter Wave)。各类别记录情况和对应的心拍数如表1所示。
表1 MIT-BIH心律失常数据库的心电图信号描述摘要表
将MIT-BIH 心律失常数据库中的48个记录进一步分为数据集I 和数据集Ⅱ这2 组数据集,其中数据集I 包含的心电图记录编号分别为101、102、104、106、108、109、112、114、115、116、118、119、122、124、201、203、205、207、108、109、215、220、223、230;数据集Ⅱ中包含的心电图记录编号分别为100、103、105、107、111、113、117、121、123、200、202、210、212、213、214、217、219、221、222、228、231、232、233 和234。本实验中,2 组数据集样本在各类心律失常标签数的具体分布情况如表2所示。
表2 MIT-BIH数据库的心律失常各类标签数目分布情况
1.2 数据预处理
原始数据在传递到模型进行训练之前需要进行预处理,如图2 所示。根据数据库中已有的R 波峰值标记对每个心电信号进行定位。由于不同个体相邻R 峰之间的平均间隔不同,使用固定数量的数据点分割心跳的方法会错过心跳的一些重要信号特征,因此,本文采用下面提出的心拍分割的方法来解决这个问题:
图2 数据预处理过程图
1)基线漂移处理。对含有波形突变的心电基线漂移信号进行平滑处理,并从原始记录数据中消去平滑信号,以消除基线漂移噪声BW(Baseline Wander Noise)。
2)工频、肌电去噪。采用截止频率为(0.5,40)Hz的带通滤波器和以db5 为基函数的离散小波变换DWT(Discrete Wavelet Transform),以消除信号电极伪影噪声EMG(Electromyography Artifact Noise)、肌肉伪影噪声MA(Muscle Artifact Noise)。
3)心拍分割。首先,读取心跳标签中提供的各心电波R 峰位置。假设Ri为第i个心跳R 峰值的位置。定义<N>表示将N四舍五入到整数,心跳的起始位置
4)心拍格式标准化。每次心跳的采样点Hi个数不同。进入深度学习模型的前提是心拍长度必须一致。假设统一后的采样点数为D(本实验中D为400),如果Hi小于D,则填0 到D,如果是Hi大于D,则裁剪至D[22],最终处理后的心跳为Hf,将对齐后的心跳H,用公式(其中,μ表示所有心跳(心拍)的均数图,σ表示标准差意为所有心拍的离散程度)进行标准化,以消除信号中的偏移和幅度缩放问题。
5)时间特征提取。人工提取当前心跳的前RR间隔、后RR间隔、局部10 s的RR间隔和平均RR等4个时间特征,并利用公式(其中,max(PRR)表示前RR 间隔,min(PRR)表示后RR 间隔,PRR表示局部10 s 的RR 间隔,平均RR 用PRR表示)生成归一化前后的RR时间特征。
6)数据增强。使用SMOTE 算法[23]生成来克服不同类别样本数量不平衡的问题。
2 模 型
2.1 模型架构
假设模型适用于输入样本,样本标注可表示为y∈Y,并进一步假设存在分布S(x,y)和T(x,y),分别表示源域和目标域的分布,假设这2 种分布相似但不同,目标是从目标分布的输入x来预测标签y。训练过程采用源域分布和少量的目标域进行,表示为{X1,X2,X3,…,XN},定义di为第i条数据样本的域标签,确定Xi是来自源域还是目标域,di=0表示来自源域;di=1表示来自目标域。所提出的对抗域自适应模型实现原则如下:定义一种改进的对抗域自适应模型,预测每个输入x对应的标签y∈Y及其域标签d={0,1}。对抗域自适应模型如图3 所示。所构建的对抗域自适应模型包括3 个模块:多尺度特征提取模块A、域识别模块B和分类模块C。
图3 对抗域自适应模型
2.1.1 多尺度特征提取模块A
为了获得足够的特征,对A模块进行了改进,将原来由一组卷积块组成的单一特征提取结构扩展为由2组具有不同卷积核的并行卷积块组成的多个特征提取结构。,将1.2节产生的输入数据Hf代入模型,提取2组不同的特征,并将其串联为特征f,即具体的结构如图4 所示,其中k表示卷积核的大小。表3 总结了多尺度特征提取模块A 体系结构和每一层的输出情况,其中为防止模型过拟合,采用Dropout函数来消除减弱神经元节点间的联合适应性,增强泛化能力,ReLU表示激活函数,Maxpool表示池化层。
图4 多尺度特征提取模块A
表3 模块A体系结构
2.1.2 对抗域自适应模块B
针对原始模型层数少、特征提取少的问题,对B识别模块进行优化,将原始的2 个全连接层扩展为3个卷积块和一个全连接层。具体的结构如图5所示。将A模块提取的特征从源域和目标域数据中转移到域识别模块B中,可以对提取的特征源d∈[0,1]进行识别,B模块的体系结构和每一层的输出情况如表4所示,其中Batch Norm表示批量归一化层,用于加快原始模型的收敛速度,同时解决梯度爆炸和梯度消失的问题。
表4 模块B体系结构
图5 对抗域模块B
2.1.3 分类器模块C
为了提高特性的维度和丰富性,对模块C进行了优化。输入Softmax 层之前,将从全连接层特征提取的源域数据与1.2 节中4 次提取的时间特征进行拼接作为最终特征分类器的输入,让时间特征和深度学习提取特征的功能更好地结合,丰富特征多样性。具体的结构如图6 所示,表5 为分类器模块C 的结构体系和每一层具体输出情况。
图6 分类器模块C
表5 模块C体系结构
2.2 训练过程
A(·)和C(·)对应的是特征提取器和分类器。在学习阶段,目标是学习特征提取器A 和任务分类器C,以最小化预期的目标损耗,使源域的数据分布与目标域的数据分布保持一致,减少域之间的差异。其中,A、B、C这3个模块对应的网络映射分别为Ga,Gb,Gc;引入联合损失函数E(ω,ω,ω),如式(1)所示。
损失函数主要包括2个部分,即分类损失L(c·,·)和域区分损失L(b·,·)。Lic和Lib表示第i个训练样本中计算出的相应损失函数。其中Lc选用的不是使用传统的交叉熵损失,而是选择Focal损失函数,是因为Focal损失函数在交叉熵的基础上增加一个动态缩放因子[24],以解决类别分类不平衡以及困难的样本难以训练的问题,自动降低简单样本的损失权重,帮助模型集中于训练更加困难的样本,其公式为:
其中,αx为x样本中该类别的权重参数,(1-Px)γ为自由缩放因子,其中γ为可调节参数,用于进一步调节缩放因子,公式部分的-log(Px)项为传统交叉熵损失函数,其中Px定义为:
其中y表示x样本中的类别标签,P∈[0,1]表示模型输出y=1类别的概率。
ωa、ωb、ωc分别为多尺度特征提取模块、域识别模块和分类模块的参数。λ表示2个学习目标之间的权重,di=0 表示第i个样本为源域样本。训练过程如式(4)、式(5)所示。
具体训练步骤如下:
步骤1 保持域识别模块ωb的参数不变,并通过式(4)计算,最大限度地减少域识别模块的损失,更新多尺度特征提取模块ωa的参数,得到域不变特征。这样可以充分获得不变特征,从而可以同时对源域数据和目标域数据进行汇总。最小化分类模块的损失,更新分类模块的参数ωc以获得一个能够准确预测标签的分类器。^ωa,^ωb,^ωc分别为鞍点ωa,ωb,ωc的参数值。
步骤2 修正参数ωa和ωc,并保持它们不变。利用公式(5)最小化域识别模块的损失,更新域识别模块ωb的参数,得到一个能够区分特征源的强鉴别器。
步骤3 重复步骤1 的操作,固定域识别模块的参数ωb不变,通过式(4)训练多尺度特征提取模块A和分类器模块C,利用该训练过程,交替更新参数。
步骤4 最终网络保持了动态平衡,达到预定的迭代次数后,得到最优值,并保存最优模型。将新的心拍样本输入到保存的最优模型中,得到最终的分类结果。
训练过程如图7所示。
图7 对抗域自适应模型的训练过程
3 实验与结果讨论
3.1 评估方法
为全面评估本研究方法心律失常分类的结果,将混淆矩阵用于评估心拍分类结果,多类分类区别于二分类任务,其TP 表示被正确检测到的心跳数,TN 表示未被正确检测到的心跳数,FP 包含被归为该类别的其他类别的心跳数,FN 包含这个类别被分类为其他类别的心跳次数。研究包括3 个评价指标:准确性Acc(accuracy)、敏感性Sen(sensitivity)和阳性预测值PPV(positive predictive value),指标定义为:
3.2 数据划分和模型调参
根据1.1 节所述的数据划分情况,将数据集I 中所有的心跳样本定义为源域数据,将数据集Ⅱ中每条记录前5 min 的心跳样本定义为目标域数据,训练数据包括源域数据及其标签和目标域数据,测试数据为整个数据集Ⅱ。训练完成后保存最优模型,整个模型使用随机梯度下降优化器进行训练,初始参数为默认参数,为防止训练集特征学得过快而丢失重要特征的学习,初始学习率设置为0.005,并设置学习率衰减因子lr_decay 为0.95,Dropout 均设置为0.3 以防止过拟合,设置为0.25 来平衡2 个分类器参数更新的权重,一轮训练的批次数batch_size 设置为64,最大迭代轮数Epochs 设置为256 轮,并设定10 轮不出现损失下降则训练停止的机制,网络通过Pytorch 高级神经网络应用编程接口(API)实现。
3.3 实验结果
3.3.1 训练情况
本文算法的训练结果如图8 所示,为模型训练和测试过程中Loss曲线和精度曲线的变化情况。从图8(a)中可以看出,测试集和训练集Loss 损失函数曲线趋势相似,虽出现波动但总体稳定下降。图8(b)训练集和测试集的精度逐渐向预期方向提高,测试集的精度没有急剧或大幅度下降,说明不存在过拟合情况。
图8 模型训练过程
3.3.2 预测效果
本文算法的测试集预测情况(数据集Ⅱ)如表6混淆矩阵所示,不同种类心拍的具体评价指标见表7。本文模型在测试集样本中的总体准确率、灵敏度、阳性预测值分别为98.8%、97.9%、98.1%,其中心室震颤的准确率在各类型中最高,达到99.8%,步速跳动标签在灵敏度指标中达到最高值99.4%,正常窦性心律标签在阳性预测值中达到最好效果99.6%。该研究发现心拍数据量少于1000 例的类别如室性逃逸和心室颤振,即使使用数据增广的方法,也无法避免对该类别因训练不足导致阳性预测值较低的结果。对于室性逃逸类阳性预测值偏低的心拍类型,需要在后续研究中进一步提取有效、重要的心电特征来提升阳性预测值。
表6 预测结果与真实结果的混淆矩阵
表7 测试效果指标评估
3.3.3 实验模型比较
为了验证多尺度特征提取器和添加时间特征对实验结果的影响,采用以下3个模型进行对比实验:
1)模型X。采用3个卷积层、3个池化层和2个完全连接层来构建,卷积层的核大小均设置为3,不采用多尺度和4 个时间特征等提取特征的方法,实现心律失常8分类。
2)模型X+Time4。将1.2 节提出的4 个时间特征添加到模型A的最后一层全连接层中,将拼接后的整体特征输入Softmax分类器进行分类。
3)本文提出的模型(模块A+B+C)。在实验过程中,采用2 组并行的多尺度和4 个时间特征等提取特征的方法。
以上3 个模型均使用相同的输入数据,分类结果对比如表8所示。
从表8 中可以看出,第一组是采用单尺度提取特征的模型,而第二组增加了时间特征,时间特征是根据专业知识手工获取的,类似于向分类器中添加专家知识,丰富特征信息,并且提升了特征的可解释性,因此有效地提高了分类的准确性。
对比均添加时间特征信息的第二组单尺度模型和第三组多尺度模型,可以看出多尺度特征提取模块可以有效提高分类精度。这是因为复杂的网络结构可以更全面地提取信号的整体特征,而不是局部特征,也可以更详细地提取部分容易被忽视的特征。由表8可知,第三组训练结果的总体准确率为98.8%,高于前2 组,并且第三组的8 种心律失常类别的阳性预测值和灵敏度均显著高于第二组。综上所述,本文提出的多尺度对抗域自适应模型的分类性能优于其他2种模型。
表8 不同模型的效果比较/%
3.3.4 与其他研究方法对比
为了使模型的有效性和泛化能力得到验证,将本文提出的模型与先前国内外ECG 心电分类研究的不同方法进行对比。为了使研究具有可比性,故针对仅使用MIT-BIH 数据库的研究进行比较,研究还需满足心律失常标签分为5~9类,且采用准确率、灵敏度、阳性预测值进行比较评估(若原研究中没有上述指标,则通过研究提供的混淆矩阵计算获得,具体计算方法参考3.1节评估方法部分为依据进行)。
表9列出了本文与国内外研究的比较情况。可以看出,在统一使用MIT-BIH国际标准心律失常数据库下,本文提出的方法在准确性、灵敏度、阳性预测值等评估指标中效果均优于表9中的其他方法。为了确保研究时间上具有可比性,表中与近5年的3篇研究文献进行比较,研究表明本文对抗域适应系统的识别性能(灵敏度和阳性预测值)高于其他方法,尤其是阳性预测值达到了98.1%,其正确识别并归为某类阳性心律失常事件的能力远高于其他方法,同时其具备较高的精度和敏感性,验证了本模型具备较好的泛化能力。
表9 本文方法与其他方法比较
4 结束语
本文针对心电信号训练数据标注较少、患者个体变异导致的数据分布差异以及特征提取单一等问题,提出了一种基于对抗域自适应的多尺度心拍分类模型,该模型是对抗域自适应理论应用于心电图信号识别分类领域的一次创新和实践。研究分别对该方法的A、B 和C 这3 个模块进行了优化,实验结果表明,采用对抗域自适应学习方法有效地提升了模型的性能和训练效果,并且在模块A中设计多尺度特征提取器,在模块C 中融合时间特征和振幅特征,可以提高提取的数据特征的丰富性和多样性。最终在5.2万条心电信号的测试集的心律失常8 分类任务中取得了98.8%的准确率、97.9%的灵敏度、阳性预测值为98.1%的效果,通过实验验证了改进的对抗域自适应学习分类模型的有效性。
多尺度提取特征尺度数量的扩充和调优、心拍分割信号点长度的选择以及分类器与时间特征、生理信息特征的融合方式等都有待进一步研究,今后将不断以既有思想和方法为基础,更好地利用对抗域自适应模型的优势,来提高心电信号识别与分类任务的精度。