基于知识蒸馏的房颤信号提取方法
2023-10-08甘兆明林家荣杨其宇
甘兆明 林家荣 杨其宇
摘要:针对基于双时域卷积网络的房颤信号提取网络模型存在的参数量大、运算资源要求高和实时性差等问题,提出基于知识蒸馏的房颤信号提取方法。该方法的教师网络和学生网络分别采用3层、1层的基于时域卷积网络(TCN),维度分别为256和32。实验结果表明,采用知识蒸馏的方法可以提高学生网络的性能,且用蒸馏后的房颤信号提取的学生网络相比于教师网络,其网络模型更小、运算资源要求更低、实时性更高,为部署到资源有限的嵌入式设备提供了理论依据。
关键词:知识蒸馏;房颤信号提取;时域卷积网络;教师网络;学生网络
中图分类号:R318;TN911.7文献标志码:A 文章编号:1674-2605(2023)03-0005-06
DOI:10.3969/j.issn.1674-2605.2023.03.005
Method for Extracting Atrial Fibrillation Signals Based on Knowledge Distillation
GANZhaoming LIN JiarongYANG Qiyu
(Guangdong University ofTechnology, Guangzhou 510006, China)
Abstract: A knowledge distillation based method for extracting atrial fibrillation signals is proposed to address the problems of large parameter quantities, high computational resource requirements, and poor real-time performance in the dual time domain convolutional network based atrial fibrillation signal extraction network model. The teacher network and student network of this method adopt a 3-layer and 1-layer time-domain convolutional network (TCN) with dimensions of 256 and 32, respectively. The experimental results show that using knowledge distillation can improve the performance of student networks, and compared to teacher networks, student networks extracted from distilled atrial fibrillation signals have smaller network models, lower computational resource requirements, and higher real-time performance, providing a theoretical basis for deploying to embedded devices with limited resources.
Keywords: knowledge distillation; extraction of atrial fibrillation signals; time domain convolutional network; teacher network; student network
0 引言
房颤是常见的心血管疾病[1],它会引发患者的并发症,如脑卒中、心衰、心肌梗死或老年痴呆等。然而,目前从患者体表采集的心电图(electrocardiogram, ECG)信号为心房和心室的混叠信号,即房颤信号与心室信号在时域和频域上都存在混叠[2]。因此,对房颤信号的实时提取具有重要的现实意义。
房顫信号的提取方法分为多导联方法和单导联方法。其中,多导联方法包括盲源分离算法[3]、时空消除算法[4]和神经网络算法[5]等,该方法需要连接多个电极,存在较大的局限性;单导联方法包括加权平均模板对消法[6]、扩展卡尔曼滤波法[7]等,该方法存在对心室形态变异性敏感、依赖高精度R峰检测算法、需要人工调参等缺点。LU等[8]利用全卷积时域音频分离网络(fully convolutional time domain audio separa-tion network,Conv-Tasnet),提出了双路时间卷积房颤信号提取网络(dual temporal convolution f-wave extraction network,DT-FENet),将Conv-Tasnet的单路编解码结构扩展为双路。DT-FENet具有较高的房颤信号提取精度,但网络模型参数量大[8]、运算资源要求高、实时性差,难以在资源受限的嵌入式设备中部署。
知识蒸馏作为一种深度神经网络的压缩方法,在行人检测、语义分割、图像超采样等应用场景都取得了良好的效果。为解决DT-FENet对运算资源要求高的问题,本文提出一种基于知识蒸馏的房颤信号提取方法,利用知识蒸馏对DT-FENet模型进行压缩。
1 相关理论基础
1.1 房颤心电信号模型
式中: 为心室信号, 为房颤信号, 为采样点数。
1.2 DT-FENet方法
DT-FENet方法是通过两组编解码器分别进行心室信号和房颤信号的编解码,以此进行心室信号和房颤信号的特征映射。DT-FENet的网络结构如图1所示。
将房颤患者的ECG信号x输入到DT-FENet后,由房颤信号编码器和心室信号编码器分别输出房颤成分编码 和心室成分编码 ; 和 分别送到估计器和解码器;信息交互模块利用房颤信号和心室信号的相关性,优化估计器输出的自注意力编码 和 ;在 上应用 、在 上应用 ,可分别得到带有自注意力特征的"Z" ?_"AA" 和Z ?_"VA" ;房颤信号解码器和心室信号解码器分别将高维的Z ?_"AA" 和Z ?_"VA" 解码,得到一维的房颤信号Z ?_"AA" 和心室信号Z ?_"VA" 。
1.3 知识蒸馏介绍
HINTON等[9]提出的知识蒸馏方法采用了“教师”和“学生”的概念,又被称为“教师-学生结构”。知识蒸馏方法中重量级的模型是结构复杂、规模大、拟合能力强的教师模型,且通常已被预先训练好;轻量化的学生模型模仿教师模型,吸收从教师模型中提炼出来的知识,具备更好的性能。ROMERO等[10]在Hinton等提出的知识蒸馏方法的基础上,提出了基于中间层的知识蒸馏。在中间层的知识蒸馏中,学生模型的学习目标是使自身的特征空间尽量靠近教师模型的特征空间,其损失函数为
式中:"H" 为交叉熵、均方误差等的评价指标, 和 分别为真实标签和模型输出, 和 分别为教师模型和学生模型的中间层特征, 和 为对应中间层的特征映射函数,D为L1损失、L2损失等距离函数。
2基于知识蒸馏的房颤信号提取方法
2.1 框架设计
基于知识蒸馏的房颤信号提取方法以教师-学生结构为基本框架,以DT-FENet为教师网络,以缩小尺寸的DT-FENet为学生网络,对网络的中间层进行知识蒸馏,其框架如图2所示。
教师网络输出的编码特征、心房掩码、心室掩码用于指导学生网络,学生网络输出的f波和QRST波与数据集中的真实标签进行比较。学生网络在教师网络标签和真实标签的监督下进行训练。
2.2 知识蒸馏方法
本文采用的知识蒸馏方法包括FitNet方法和AT方法。FitNet方法的教师模型和学生模型之间使用一维卷积进行特征映射,即将维数较低的学生模型中间层特征,通过单层的一维卷积网络映射到与教师模型中间层一致的维数,其损失函数为
式中:非线性函数r为一维卷积网络,C为教师模型中间层的通道数,n为样本长度。
AT方法的教师模型和学生模型都加入了注意力机制,作为教师模型和学生模型的特征转换。距离函数选择均方误差(mean square error,MSE),其损失函数为
式中: 为教师模型的中间层通道数, 为学生模型的中间层通道数, 和 分别为教师模型和学生模型的注意力图,n为样本长度。
AT方法的注意力机制将中间层各通道的值求和并进行L2范数归一化,使该层神经元整体的激活程度成为权重,令样本的不同区域对应不同程度的注意力,并将该注意力信息通过知识蒸馏传递给学生网络。
3实验和结果分析
3.1 数据集
本实验的训练集和验证集均源自Castilla-La Manch数据库(CLMDB)[11],测试集来源于PhysioNet的MIT-BIH Atrial Fibrillation数据库(AFDB)[12]和PhysioNet/Computing in Cardiology Challenge 2017数据库(Challenge 2017)[13]。CLMDB中的样本被随机划分为训练集(70%,168条)和验证集(30%,72条),并被切割成时长为30 s的片段。为了抑制测试集样本的基线漂移,采用截止频率为0.01 Hz的一阶巴特沃滤波器对AFDB和Challenge 2017中的样本进行高通滤波。
3.2 损失函数
LU等在提出DT-FENet方法时,选择信噪比(signal noise ratio, SNR)为训练网络时的损失函数,其定义为
式中: 为训练集的批尺寸, 和 分别为从每条训练样本提取的房颤信号和心室信号的信噪比, 和 分别为训练样本的真实房颤信号和模型提取的房颤信号, 和 分别为训练样本的真实心室信号和模型提取的心室信号, = 0.05是心房通道和心室通道之间的折衷系数。
值得注意的是,房颤信号提取模型的训练目标是最大化信噪比值,而神经网络的训练过程只能使目标值单调递减,所以在信噪比前加负号,将模型的训练方向设置为SNR值的最大化。这一设计是损失函数的值恒小于0且其绝对值逐渐变大。然而,知识蒸馏方法的距离函数采用了MSE,在进行知识蒸馏时的优化目标是最小化MSE,即损失函数的值恒大于0且其绝对值逐渐变小。因此,为了正常训练损失和知识蒸馏损失有一致的趋势,有必要对正常训练损失进行调整。考虑到归一化均方误差(normalized mean squared error, NMSE)和SNR的计算公式在形式上有倒数的关系,本文将正常训练损失由SNR改为NMSE,其定义为
综上所述,知识蒸馏的损失函数由正常训练损失L_"train" 和知识蒸馏损失L_"KD" 两部分组成。采用FitNet方法时,L_"KD" 对应公式(15);采用AT方法时,L_"KD" 对应公式(16):
对于FitNet方法, 和 分别是编码器输出的心房编码蒸馏特征和心室编码蒸馏特征, 和 分别是TCN网络输出的心房编码蒸馏特征和心室编码蒸馏特征。AT方法类似,这里不再重复叙述。
3.3评价指标
因为测试集AFDB和Challenge2017都来源于临床的ECG监测,仅包含心房和心室的混合信号,无法提供房颤信号作为真实值,所以在測试过程中无法直接使用模型提取的波形与真实值的拟合误差来表示模型的提取精度。本文采用频谱集中度(spectral concentration, SC)[14]替代SNR和NMSE来作为评价指标,房颤信号的频谱较窄,绝大部分频谱能量集中在3~12 Hz的区间内,若提取的房颤信号频谱集中度较高,说明房颤信号的失真小,即算法的提取精度较高。SC的计算公式为
式中: 为样本的采样率, 为房颤信号的功率谱。
3.4实验设置
教师模型和学生模型的参数设置如表1所示。
3.5实验结果和分析
为验证知识蒸馏对房颤信号提取的有效性,对比教师网络、FitNet学生网络和AT学生网络的房颤信号提取效果,其在验证集和测试集的提取精度分别如表2、表3所示。
由表2可以看出,SC与SNR、NMSE具有较好的一致性,从一定程度上证明了SC是一种有效的房颤信号提取精度评估指标,适用于数据集未提供实值时的房颤信号提取精度的评估。
由表3可以看出,使用FitNet方法知识蒸馏所得的学生网络,在AFDB上的房颤信号提取精度优于AT方法网络和教师网络的提取精度。在Challenge2017上,使用FitNet方法和AT方法的知识蒸馏所得学生网络的SC与正常训练所得学生网络的SC相当。
教师网络、FitNet学生网络和AT学生网络模型的资源占用情况如表4所示。
由表4可知,学生模型的总参数量和总浮点运算数约是教师模型的三分之一。
结合表3可知,学生网络用三分之一的资源开销在Challenge2017测试集上取得了和教师网络相当的效果,证明了基于知识蒸馏的房颤信号提取方法的有效性。
4 结论
针对现有的房颤信号提取网络DT-FENet存在规模大、消耗资源多的问题,本文通过知识蒸馏的方法对DT-FENet网络模型进行压缩,学生网络用三分之一的资源开销在Challenge2017测试集上取得了和教师网络相当的效果,解决了DT-FENet难以部署在嵌入式设备的问题。本文只利用知识蒸馏对DT-FENet进行压缩,还可以利用剪枝和量化等方法进行下一步研究。
参考文献
[1] 国家心血管病中心.中国心血管健康与疾病报告2020[J].心肺血管病杂志,2021,40(10):5.
[2] LIN C H. Frequency-domain features for ECG beat discrimina-tion using grey relational analysis-based classifier[J]. Compu-ters& Mathematics with Applications, 2008,55(4):680-690.
[3] CASTELLS F, RIETA J J, MILLET J, et al. Spatiotemporal blind source separation approach to atrial activity estimation in atrial tachyarrhythmias[J]. IEEE Transactions on Biomedical Engineering, 2005,52(2):258-267.
[4] LEMAY M, JACQUEMET V, FORCLAZ A, et al. Spatiotem-poral QRST cancellation method using separate QRS and T-waves templates[C]//Computers in Cardiology, 2005. IEEE, 2005:611-614.
[5] V?SQUEZ C, HERN?NDEZ A, MORA F, et al. Atrial activity enhancement by Wiener filtering using an artificial neural network[J]. IEEE Transactions on Biomedical Engineering, 2001,48(8):940-944.
[6] DAI H, JIANG S, LI Y. Atrial activity extraction from single lead ECG recordings: Evaluation of two novel methods[J]. Computers in Biology and Medicine, 2013,43(3):176-183.
[7] ROONIZI E K, SASSI R. An extended Bayesian framework for atrial and ventricular activity separation in atrial fibrillation[J]. IEEE Journal of Biomedical and Health Informatics, 2016, 21(6):1573-1580.
[8] LU J, LUO J, XIE Z, et al. Dual temporal convolutional network for single-lead fibrillation waveform extraction[J]. Neural Com-puting and Applications, 2021,33(22):15281-15292.
[9] HINTON G, VINYALS O, DEAN J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015, 2(7).
[10] ROMERO A, BALLAS N, KAHOU S E, et al. Fitnets: Hints for thin deep nets[J]. arXiv preprint arXiv:1412.6550, 2014.
[11] STRIDH M, SORNMO L. Spatiotemporal QRST cancellation techniques for analysis of atrial fibrillation[J]. IEEE Transac-tions on Biomedical Engineering, 2001,48(1):105-111.
[12] ALCARAZ R, S?RNMO L, RIETA J J. Reference database and performance evaluation of methods for extraction of atrial fibrillatory waves in the ECG[J]. Physiological Measurement, 2019,40(7):075011.
[13] CLIFFORD G D, LIU C, MOODY B, et al. AF classification from a short single lead ECG recording: The PhysioNet/ com-puting in cardiology challenge 2017[C]//2017 Computing in Cardiology (CinC). IEEE, 2017:1-4.
[14] DAI H, JIANG S, LI Y. Atrial activity extraction from single lead ECG recordings: Evaluation of two novel methods[J]. Computers in Biology and Medicine, 2013,43(3):176-183.
作者簡介:
甘兆明,男,1995年生,硕士研究生,主要研究方向:模式识别、机器学习、生物信号处理。E-mail: 1803158832@qq.com
林家荣,男,1997年生,硕士研究生,主要研究方向:模式识别、机器学习、生物信号处理。E-mail:1440645304@qq.com