基于压缩卷积神经网络的心律不齐分类方法①
2023-10-28韩传奇崔莉
韩传奇 崔莉
(*中国科学院计算技术研究所 北京 100190)
(**中国科学院大学 北京 100190)
0 引言
根据世界卫生组织的统计数据,心血管疾病(cardiovascular diseases,CVDs)每年可导致约1790万人死亡,是世界范围内致死率第一的疾病[1]。2017 年,我国每5 例死亡病例就有2 例死于心血管疾病[2]。严峻的心血管疾病形势给患者个人、家庭和国家都带来了沉重的精神压力和经济负担。而根据相关调查,大约有90%的心血管疾病病例可以通过早期监测配合健康的生活方式得到预防或改善[3]。
心律不齐是一种较为常见的心血管疾病,主要由心脏组织和活动的变化而引起,严重时会导致昏厥、心力衰竭甚至猝死。近年来,以Apple Watch、Huawei Watch 为代表的消费级可穿戴设备通过搭载单导联心电信号(electrocardiogram,ECG)传感器,实现了以房颤为代表的心律不齐筛查功能。可穿戴设备凭借低成本、便携的特点,非常适合人们日常使用以实现主动式健康管理的目的,得到了人们的认可。然而,现有的产品在ECG 分析过程中需要云端的支持。具体来说,可穿戴设备在使用ECG 分析功能前,需要先通过App 与移动设备建立连接。测量时,由可穿戴设备搭载的传感器采集心电信号,之后通过蓝牙等通讯方式将数据传输到移动端,再借由移动端网络上传到云端。该数据经由医生诊断或云端部署的算法推理得到结论后,将信息返回给用户。上述过程将算力集中在云端服务器,需要可穿戴设备与移动设备传送大量的数据,增加了等待时延和用户隐私泄露风险。因此,设计一种可部署在移动端甚至是可穿戴设备上的模型,实现人工智能物联网“端-边-云”协同,对于提升心脏健康辅助诊断实用水平具有重要的意义。
在学术界,利用人工智能技术对单导联ECG 数据进行心律不齐的辅助诊断已有较多的研究,主要可分为基于特征工程的机器学习方法和基于端到端的深度学习方法。采用特征工程的机器学习方法首先需要对数据进行预处理,之后研究人员根据医学先验知识和自身经验设计并优化特征集,并将这些特征投入到诸如线性判别器(linear discriminants,LD)、支持向量机(support vector machine,SVM)等分类器中。此类方法具有模型体积小以及可解释性强等优点。然而特征工程通常需要耗费大量的时间,对研究人员的医学知识储备亦提出较高的要求。而在大数据时代下,基于SVM 等传统机器学习的算法在准确率方面亦容易遇到瓶颈[4]。相比之下,深度学习模型具有端到端的训练推理方式,可以直接以经过预处理甚至原始的心电信号为输入,自动完成特征提取及分类。但为了抽象出包含复杂病理信息的特征,神经网络结构通常较为复杂,参与训练的参数量规模可以达到百万级甚至千万级以上,对部署环境的算力提出了严苛的要求,不利于模型向移动端甚至边缘端迁移和部署。此外,模型参数更新依赖于数据驱动,因此基于深度学习的方法在训练时对数据集的质量有较高的要求。然而受采集难度等因素影响,目前业界公开的心律不齐数据库通常规模较小且不同类别的样本分布严重不均衡。以业界著名MIT-BIH 心律不齐数据库[5]为例,该数据库包含了47 位个体的ECG 数据,仅有约30%的心拍类别为异常类别。在这种情况下,不同个体间生理存在的差异会对模型训练产生较大的干扰[6],这些因素共同导致深度学习模型在面对新的个体数据时无法对部分异常类别样本开展实用化的筛查。
为解决上述问题,本文在现有经典网络的基础上通过分析卷积神经网络特征图维度及心电样本特点,提出了一种压缩方案,使得模型能够在大幅减少参数量的前提下保持较优的分类能力。同时通过将类别先验概率引入损失函数,使用logit adjusted loss[7]训练模型,增强了模型在识别异常类别上的性能。
论文的其余部分组织如下。第1 节介绍国内外相关研究工作。第2 节给出本文所采用的心电分类算法框架,同时介绍了样本预处理方法。第3 节介绍本文提出的压缩卷积神经网络和采用的logit adjusted loss 训练模型细节。第4 节给出实验介绍和结果分析。第5 节对论文进行总结。
1 相关研究
在基于特征工程的传统机器学习方法中,文献[8]基于医学经验设计了8 组时间域特征集,其中每组特征集包含RR 间期(RR-intervals)等20 余个特征。之后研究人员利用线性判别器按照美国医疗仪器促进协会(The Association for the Advancement of Medical Instrumentation,AAMI)标准对5 种心律不齐做分类判别,其实验过程和结果在业界树立了坚实的基准。文献[9]则在此基础上通过引入小波变换后的频域特征提升分类效果。文献[10]亦采用了线性判别器为分类器,所选用的特征除了归一化的RR 间期外,还有利用小波分析和线性预测建模(linear prediction modeling)从原始波形中提取特征。文献[11]和[12]分别利用粒子群算法和遗传算法优化特征选择,之后以SVM 为分类器进行心电异常识别。不同于上述采用单分类器的方法,文献[13]先将PT 段波形作为输入,利用SVM 中进行心律不齐初级分类;之后根据基于RR 间期的先验规则进行二级分类,并在仿真实验中取得了良好的泛化表现。
近年来,深度学习已在计算机视觉、自然语言处理等领域取得了巨大的成就。凭借深度学习强大的数据挖掘能力,利用其进行心律不齐分类也得到大量科研人员的研究。文献[14]通过设计33 层的卷积神经网络(convolutional neural network,CNN)取得了专家级精度的分类算法。文献[15]利用自动编码器(AutoEncoder)对原始信号进行降维,并通过长短时记忆网络(long-short term memory,LSTM)捕捉时序依赖进行分类。文献[16]则在单层LSTM 基础上叠加了一层反向网络,以此捕捉双向时序关联信息,并通过引入注意力机制进一步优化特征权重。最近话题热度很高的Transformer 模型[17],在心电分类任务上也得到了相关应用[18-19]。这些方法在分类精度方法取得了比传统机器学习方法更优的评价指标,但突出的问题是模型体积过大,如文献[14]的可训练参数量达到千万级,为训练如此庞大的模型,其采用的私有数据集包含有53 549 位个体的91 232条ECG 数据。而基于规模较小的公开数据集的方法中,通常采用SMOTE(synthetic minority oversampling technique)[20]等生成式数据增强方法对数量较少的异常样本进行扩充,从而增强模型对异常类别的识别能力。然而这种方法增加了模型的训练负担和周期。此外对于心律不齐分类任务来说,心电图微小的差异都可能对应不同的诊断结果,因此并非所有的生成数据均能保证有效可靠。
2 心律不齐分类方法
2.1 心律不齐分类框架
心律不齐分类任务的目标是,给定心电图样本集X作为输入,通过模型推理得到标签集Y′,使其与真实标签集Y的误差尽可能小。为此,本文提出的心电分类通用架构如图1 所示,其主要分为3 个部分,分别是数据预处理、模型构建与心律不齐分类。数据预处理则包含样本去噪、样本切片、质量评估和样本归一化等。
图1 心律不齐分类任务框架图
2.2 数据预处理
心电信号在采集过程中可能受到环境干扰或动作扰动的影响而带有噪音。为获取高质量心电信号,本文采用的去噪策略如下:首先,分别以200 ms和600 ms 为窗口宽度进行中值滤波,获取心电图基线;之后通过做差去除基线飘移;接下来,以coif5 为母小波,利用离散小波变换(discrete wavelet transform,DWT)的方法去除高频噪音。经过上述处理,心电信号已具备较高的信噪比,初步满足后续分析需求。
在得到高质量数据后,需要将数据切分成固定长度的样本作为模型的输入。如图2 所示,正常的心拍波形图包含5 个关键波群,分别为P 波、QRS 波群和T 波。其中P 波反映心房的除极活动,QRS 波群和T波则分别反映心室的除极和复极活动。为保证固定长度的样本包含上述波群,本文首先利用Pan-Tompkins 方法[21]定位心电图中最明显的R 波,之后以R 波为基准,取其前0.28 s 和后0.45 s 的心电信号作为波形样本。经过切片后的样本为单心拍的ECG 波形数据,为了保留对心律不齐分类有重要作用的多心拍时序关联信息,在切片过程中同时记录下当前心拍与前后相邻心拍间的RR 间隔,记为RR_pre 和RR_post,并据此计算出个体平均心拍间隔RR_avg。以360 Hz 采样率为例,经过上述处理后获得了由260 个采样点组成的波形数据和3 个时间间隔数据组成的心拍样本。
图2 正常心拍的ECG 波形图
在送入模型前,样本还需要经过质量评估。未通过评估的样本将被视为噪音数据而被丢弃。在本文中,当RR_pre 或RR_post 大于2 s 时,会被认为当前心拍样本在采集过程中出现了电极片脱落等意外事件,进而被标记为噪音数据。每条通过质量评估的ECG 信号x,其波形幅值将进一步经过minmax 标准化处理,使其范围控制在[0,1]内:
标准化处理后的心电信号为xn,利用xn进行训练有主题加快模型的收敛速度。
3 压缩卷积神经网络分类模型
3.1 基础模型构建
完整的卷积神经网络通常包含卷积层、池化层和全连接层。其中在卷积层中包含有多个卷积核(kernel)与该层输入进行卷积运算。在图像卷积中,浅层的卷积核通常起到检测边缘等局部信息的作用。而伴随着网络的加深,卷积核提取到的特征也越来越丰富,表达能力也更强。
VGG16 网络是在2014 年由牛津大学和Google DeepMind 公司的研究人员合作研发的深度网络模型[22]。该网络包含13 个卷积层、5 个池化层和3 个全连接层。VGG16 在ILSVRC 2014 比赛中大放异彩,取得了分类项目第2 名、定位项目第1 名的优异成绩。VGG 网络结构简洁、泛化性能好,至今仍然在许多研究中担任骨干网络(backbone)。本文以该网络为基础,设计心律不齐分类模型。
由于VGG16 在设计时的初衷是针对计算机视觉任务,其输入为3×W×H的红绿蓝(red green blue,RGB)图像。而在心电分类任务中,其输入为维度1×W的时序信号,因此本文将该网络中的二维卷积及池化操作更改为一维运算并命名为1d-VGG。由于波形数据通过卷积网络后将映射为高维度的特征图(feature map),为取得相同维度的时间域特征,RR_pre,RR_post 及RR_avg 亦将通过2层卷积网络做嵌入升维(embedding),并同波形特征图拼接后作为全连接网络的输入。上述网络的结构如图3(a)所示。
图3 1d-VGG 和cpr-VGG 网络结构图
经过计算,1d-VGG 网络的训练参数量为28.4 M,规模较为庞大。需要进一步压缩以适配移动端部署条件。
3.2 模型压缩
卷积神经网络随着层数的加深,其特征图感受野也在扩大。以图3(a)所示的结构为例,若以连续多个卷积层和1 个maxpool 层为1 个卷积模块(block),则图中包含的5 个卷积模块输出的特征感受野如表1 所示。对于浅层特征,其感受野相对较小,关注的是ECG 局部细节信息;而对于深层特征,其侧重捕捉心电图局部特征的上下文关联。
表1 1d-VGG 网络卷积模块的特征感受野
在心律不齐分类中,异常的病症会导致ECG 关键波群形态有明显变化。比如房颤(atrial fibrillation,AF)发生时通常伴随着P 波消失和QRS 波肥大[23]。这些异常变化对于准确识别异常病症至关重要。正常人的P 波以及QRS 波的时长通常不超过0.11 s,以360 Hz 的采样率计算,上述波群在样本中约占40 个采样点。根据表1 不难看出,浅层的卷积层可以起到异常波群检测的作用。而对于深层卷积层提取的特征,它们的维度更高,表达能力更强,但分辨率较低。结合心律不齐分类的诊断依据,其包含的冗余信息也相应增加。这些信息对心电分类助力有限,反而可能使模型过分拘泥于病人的一些个体差异,降低泛化性能。事实上,文献[24-26]在图像分类任务中亦对VGG 网络深层特征图包含较多的冗余信息有近似的结论。
为此,在1d-VGG 的基础上对深层网络的参数进行调整(见图3(b)),具体来讲,保留1d-VGG 骨干网络中前2 个卷积模块中的设定。伴随着深度的加深,卷积核数量的压缩率也相应增加(见表2),以此促使网络在深层特征选择上关注更多有用的信息。RR 时序特征嵌入网络和全连接网络参数也根据波形特征图维度的变化做相应的调整。经过上述操作,cpr-VGG 网络的训练参数量压缩为0.5 M,相比压缩前降低了98.2%。
表2 cpr-VGG 网络卷积核压缩率设置
3.3 模型训练
分类任务中,真实类别为y的样本x在经过模型f(·) 的非线性拟合后,对每一个类别输出一个logit,即fy′(x)。这些logit 通过Softmax 函数,映射为样本x对应不同类别L的概率。对于其真实类别y,输出概率为
通过对数据集中每一个样本进行上述映射,可以得到一个拟合分布Q,其与真实分布P之间的差异可通过KL 散度(Kullback-Leibler divergence)计算:
其中,H(P(x)) 为真实分布的熵,其数值在一个数据集上为定值。因此,衡量上述差异的关键在于式(3)的后一项。将式(2)代入该项中,可得到样本x对应的交叉熵(cross entropy,CE)损失:
在计算得到该损失值后,便可以通过反向传导更新模型参数,使得拟合分布与真实分布的差异逐步缩小。对于数据集D中的多个样本,其在反向传导所用的损失函数值为多个样本结果的平均值:
其中N为样本个数。式(5)为分类任务中最常用的损失函数。
通过优化式(5)训练的模型将取得最小平均误差。然而如前文所述,在实际训练时采用的数据库往往存在着严重的疾病类别不平衡的情况。在这种情况下,将数据集D视为|L| 个类别数据的集合,即D=,其中Dj包含的样本数为nj,则式(5)可以进一步推导为
由式(6)不难看出,对于数据集中数量最多的优势类别(通常为正常类别),其在计算损失函数中占有的权重也最大。因此,采用传统的交叉熵损失函数在类别分布不平衡的条件下,将导致模型更关注优势类别的误差率而容易对数量较少的异常样本造成误诊。
为使模型能公平地对待每一种类别上的误差,引入平衡错误率[27](balanced error rate,BER),其定义如下:
与式(6)相比,式(7)在计算错误率时对每个类别使用了均衡的类别概率。为了得到最小平衡错误率,期望模型f(·) 学习到的数据分布Pbal(y|x)具有均衡的类别概率。为方便说明,将理想的分类器记为f*(·),从而有:
根据贝叶斯公式,式(9)又可推导为
注意到在式(2)中,有P(y|x) ∝exp(fy(x)),将其代入式(10),便可以推出:
由式(11)不难看出,取得最小平衡错误率的理想分类器f*(·) 与当前取得最小平均误差的分类器在logit 上相差一个偏移项lnP(y)。因此只需要将该偏移项与模型logit 相加后代入式(4)表示的传统交叉熵损失函数,得到式(12),便可以通过优化该损失函数训练得到理想分类器。
4 实验及验证
4.1 数据来源
本文实验采用业界知名MIT-BIH(MITDB)心律不齐数据库。该数据库包含48 条时长约为30 min的二导联数据,采样率为360 Hz。按照文献[4]方法,移除4 条含有起搏器因素的数据。为适应可穿戴设备的应用需求,仅利用II 导联数据作为实验分析对象。之后,按照AAMI 标准将MIT-BIH 数据库标签映射为5 大类:非异位搏动(N)、室上性异位搏动(S)、室性异位搏动(V)、混合搏动(F)及未知搏动(Q)。其映射关系如表3 所示。
表3 类别标签映射关系
在划分数据集时,采用更能测试模型泛化性能的患者间范式(inter-patient paradigm),其中训练集和测试集均包含来自22 位不同个体的数据,相应的个体编号及样本类别分布情况见表4。由于F类与Q类的样本数量过少,在后续讨论中本文将聚焦于其他3 个主要类别的效果评估中。
表4 数据集样本分布统计表
4.2 评估指标
本文选用灵敏度(sensitivity,Sen)、精准率(precision,Pre)和F1 值作为衡量模型分类表现的客观指标。灵敏度反映的是模型对正例的查全水平,而精准率反映的是模型对正例的查准水平。由于灵敏度和精准率在多数情况下是一对矛盾的度量指标,因此有F1 值作为权衡二者的综合度量。对于多分类任务,进一步计算了宏F1(macro-F1)。上述指标的计算公式如式(13)所示,其中TP、FP和FN分别代表真阳性、假阳性和假阴性样本个数。
4.3 训练设定
本文方法基于Python 3.7 实现,采用Pytorch 1.7.1作为深度学习框架,在高性能GPU 服务器上进行实验。具体的配置为AMD EPYC7502 32 核处理器,64 GB 内存,RTX3090(24 GB)显卡。
本文实验采用Adam(adaptive moment estimation)[28]优化器,它可以计算每个参数的自适应学习率。在实际应用中Adam 被证实具有收敛速度快、学习效果良好等优点,其默认参数就能解决许多神经网络优化问题,因此本文的初始学习率亦设为默认的0.001。本文的批大小(batch size)设为128,该设定亦被许多相关文献采用[14-15,19]。为衡量模型在训练中的表现,从训练集中随机选取30%数据作为验证集。当验证集损失函数在连续5 轮迭代(epoch)中未出现下降时,模型学习率降为之前的10%。当损失函数连续12 轮未下降时,模型停止训练,输出训练过程中取得最小损失函数的模型参数用于后续测试。具体训练算法如算法1 所示。
4.4 实验结果与分析
本文以未经压缩的1d-VGG 为baseline 模型,同时以压缩后的cpr-VGG 作为对照组,分别按照上一节内容展开训练。二者在测试集的分类混淆矩阵分别如表5 和6 所示。之后,根据式(13)计算相关评价指标,对比结果如表7 所示。为衡量logit adjusted loss 模型的作用,本文进一步开展了以经典交叉熵作为损失函数的消融实验。实验结果亦列于表7 中。从该表可以看出,经过压缩后的cpr-VGG 模型在S 类的F1 值增加0.01,N 类的F1 值下降0.01,而V 类的F1 值下降最为明显。之所以出现如此明显的下降,是因为N 类的灵敏度下降0.02,而由于N 类是主导类别,占据测试集样本总数的90%,其灵敏度的下降将极大影响其他弱势类别的指标。例如,有近1301个N 类样本被误分为V 类,进而削弱了cpr-VGG 在识别该类异常的特异度。然而考虑到cpr-VGG 相比1d-VGG 在参数量上减少了98.2%而宏F1 值仅下降0.03,因此本文认为cpr-VGG 凭借更轻巧的特点更适合赋能可穿戴设备实现本地心律不齐分类功能。
表5 1d-VGG+logit adjusted loss 分类混淆矩阵
表6 cpr-VGG+logit adjusted loss 分类混淆矩阵
表7 消融实验对比结果
为进一步分析cpr-VGG 的分类表现,本文对同样采用MITDB 数据库的典型相关文献进行对比。从表8 可以看出,cpr-VGG 取得了最优的宏F1 值,表明其可以较好地完成多分类心律不齐分类任务。
表8 cpr-VGG 与相关工作的对比结果
在采用交叉熵为损失函数的消融实验中,根据表7 不难看出1d-VGG 和cpr-VGG 的宏F1 值均出现了明显的下降,其主要原因是模型在室上性异位搏动(S)类别上表现较差:灵敏度分别为0.42 和0.43。与此同时,大量的该类样本被误分为主导类别N 类。上述结果印证了深度学习模型在类别不均衡条件下训练可能导致其对数量较少的异常样本误诊率提高,同时也体现出logit adjusted loss 模型对于缓解该问题的突出贡献。需要注意的是,采用交叉熵作为损失函数并未对室性异位搏动类样本(V 类)识别上造成明显的影响。为分析原因,本文将N、S 和V3类样本的平均波形图绘制如图4所示。可以直观地发现,V 类的形状与其他2 类有较为明显的差异。这种差异可以帮助模型在训练时将V 类样本映射到与其他类别边界距离足够大的空间,因此可以降低类别不平衡带来的不利影响。
图4 心律不齐主要类别样本平均波形图
5 结论
具备健康管理功能的可穿戴设备引起了市场的强烈反响。但是受制于资源限制,在移动端甚至边缘端部署深度学习模型对于网络结构设计提出了严苛的要求。此外,数据类别不平衡问题极大地影响着模型在异常类别上的分类性能,对模型的实用性提出严峻的挑战。针对该问题,本文提出了一种经过压缩的cpr-VGG 模型,实现单导联心律不齐分类。相比于1d-VGG 模型基准,新的模型大幅减少参数数量,提升了模型的部署能力。另一方面,为改善模型在类别不平衡数据集上的分类表现,本文以引入类别先验分布的logit adjusted loss 作为损失函数,在不增加训练负担的前提下改善了模型的性能。在实验验证中,cpr-VGG 表现出与未经压缩的1d-VGG 相近的分类性能,与典型同类工作相比也取得了更高的多分类评价指标。未来的研究工作包括设计更精细化的剪枝压缩方法和构建更合理的样本组织形式,以进一步提升模型在心律不齐分类中的综合表现。