一种基于元学习的医疗文本分类模型
2022-12-27赵志桦
赵 楠, 赵志桦
(1.海军军医大学第二附属医院信息中心,上海 200003;2.海角科技政策服务部,浙江 杭州 310000)
医疗文本包含病历文书、诊断报告、病理报告、影像报告等大量的医疗信息,是进行疾病预测、辅助诊疗、药物研发、个性化信息推荐、临床决策支持等的重要文本资源[1]。医疗文本具有复杂的专业术语,常见的传统文本分类模型在处理医疗文本时常常面临文本语义稀疏、训练数据不足等问题[2,3]。
迁移学习可以从大量的先验任务中学习元知识,利用以往的先验知识来指导模型在新的小样本任务中更快地学习[4]。Finn等人[5]提出MAML方法,从少量数据中进行训练从而得到较好的分类结果,Xiang等人[6]在2018年将MAML迁移到文本分类领域,并且根据不同的词在句子中的重要性加入注意力机制赋予词语不同的权重,提出ATAML模型。为了加强模型对新的医疗文本数据的适应性,文中基于领域自适应性使用已经训练学习过的某个或多个领域上的知识快速学习适应另一个新领域,利用两个相互竞争的神经网络,分别扮演领域识别者和元知识生成者的角色,提出一种基于注意力机制的领域自适应元学习模型ADAML(Attentive Domain Adaptation Meta-learning,ADAML),最后使用岭回归获得文本分类的最终结果[7]。
1 相关工作
1.1 小样本文本分类
小样本文本分类是一项小样本学习在文本分类问题上的具体任务,解决一些现实任务中样本数据不足或者难以获得大量高质量标注数据的问题,比如在医疗领域,一个数据样本可能就是一个临床试验或者一个真实的患者信息,这时基于大数据训练的深度学习不能有效解决此类问题,需要一种具有快速学习能力的模型,可以将之前学习到的知识用于识别新的数据集[8,9]。小样本文本分类通常考虑的是N-wayK-shot分类,这种分类任务训练集Dtrain一共包含I=KN的样本,其中N表示类别个数,K表示每个类别中的样本个数。
1.2 小样本学习模型
小样本学习模型主要分为三大类:基于模型微调、基于数据增强和基于迁移学习。基于迁移学习的方法细分为基于度量学习、基于元学习和基于神经网络3类[4,10]。元学习(meta-learning)是目前比较前沿的方法,其目的是让模型获得一种学习能力,能够举一反三[11],在源领域上使用大量已有数据进行训练,提高学习性能,达到目标领域上减少对数据集规模的依赖。因此,对于目标任务只有较少训练数据的情况,元学习模型会取得更好的效果,实现跨任务的学习共享[12-14]。
1.3 医疗文本分类方法
传统的医疗文本分类侧重于进行医疗数据的自由分类,一般包括四个步骤:医疗文本预处理、特征选择、医疗文本分类、结果评估[15]。分类的主要目标在于从自然语言中分析语义和总结归纳医疗信息,从而简化医疗信息管理过程。医疗文本大多是无规律非结构化的数据,并且大多是高维稀疏数据集,这也是医疗文本自然语言处理过程中的难点,以往基于深度学习的文本分类模型是对预处理文本数据进行训练学习,获得文本分类结果。
2 基于元学习的医疗文本分类模型
2.1 ADAML模型
2.1.1 注意力生成器
注意力生成器利用大型源领域数据集来识别一般词语的重要性,并且利用小目标域支持集来估计特定类别的词语重要性。根据频繁出现的虚词在文本分类中往往不太可能提供太多信息的现象来降低频繁词的权重并提高稀有词的权重[16]。文中采用文献[16]的方法定义统计量反映特定类别词语的重要性。
(1)
xi是输入的第i个单词,ε=10-3,P(xi)是xi在源领域的重要性。支持集中有区别的词在查询集中也有可能有区别,然后定义如下的统计数据来反应词语的重要性。
t(xi)=H(P(y∣xi))-1
(2)
其中条件似然P(y/xi)是支持集上的极大似然,H(·)是熵算子,t(·)是根据频率分布的加权。在迁移学习过程中注意力机制能够从文本编码序列中检索目标任务,获得特定任务的表示,文中使用双向 LSTM融合输入信息,计算点积来预测单词xi的注意力分数。
(3)
hi是i处的双向LSTM输出,v是可学习的向量。注意力机制通过结合源领域数据集和目标域支持集的分布统计来生成特定类别的注意力,提供词语重要性的归纳偏差。
2.1.2 元知识生成器和领域判别器
领域自适应性需要混淆源领域数据集和目标域查询集的样本才能实现有效的领域转移,元知识生成器就是要尽可能使领域判别器无法区分目标域查询集和源领域数据集的样本。元知识生成器对双向LSTM输出采用单层前馈神经网络,使用Softmax函数来获得元知识表征向量kp。
kp=Softmax(ω·hp+b)
(4)
kp是一个n维度的向量,代表句子p中包含的元知识,n表示该句子的长度。
领域判别器通过一个三层前馈神经网络区分样本是来自源领域还是目标域,其输出0 或 1 分别代表样本来自目标域查询集或源域数据集。
2.1.3 交互层和分类器
(5)
ADAML选择岭回归作为分类方法,岭回归是一种有偏估计回归,主要用于共线性数据分析[17]。分类器由每个元任务的目标域支持集从头开始训练,通过适当的正则化减少对小支持集的过度拟合。
(6)
2.1.4 损失函数
在每次训练迭代中ADAML首先固定元知识生成器和领域判别器的参数,通过目标域支持集更新分类器的参数,其中分类器的损失函数如公式(6)所示。 接下来,ADAML固定元知识生成器和分类器的参数,通过目标域查询集和源域数据集更新领域判别器的参数,并使用交叉熵损失作为领域判别器的损失函数。
(7)
其中μ表示领域判别器的参数,m表示目标域查询集或源域数据集的样本数,yd根据其值为0或1表示样本是来自目标域查询集还是源域数据集,k代表元知识向量。最终,ADAML固定领域判别器和分类器的参数,通过目标域查询集和源域数据集更新元知识生成器的参数。元知识生成器的损失函数由两部分组成:第一个是最终分类结果的交叉熵损失;第二个是与领域判别器的损失相反的损失,即混淆判别器。
LG(β)=CELoss(f(W·Gβ(W),y)-LD
(8)
其中β表示元知识生成器的参数,f表示岭回归因子,W表示一个句子中的词向量矩阵,y表示样本的真实标签,LD的定义在公式(7)中。
2.2 算法流程
ADAML模型主要有注意力生成器、元知识生成器、领域判别器、交互层和分类器几部分,目标域支持集S或查询集Q中的类别数为N,S中每个类别的样本数为K,Q中每个类别的样本数为L,源域数据集Φ,其具体算法流程如下伪代码:
算法:ADAML训练数据集Input:训练数据集{Xtrain ,Ytrain };元任务个数T和迭代轮次ep;生成器的参数β;判别器的参数μ;分类器的参数θ.Output: 训练结束后的参数β和μ;随机初始化模型参数β,μ和θ;for each i∈1,ep doY←Λ(Ytrain ,N); for each j∈1,T doS,Q,Φ←∅,∅,∅; for y∈Y do S←S∪Λ(Xtrain {y},K); Q←Q∪Λ(Xtrain {y}S,L); Φ←Φ∪Λ(Xtrain Xtrain {y},L); 将参数S导入模型; 修正参数β,μ,更新参数θ以最小化公式(6); 将参数Q,Φ导入模型; 修正参数β,θ,更新参数μ以最小化判别器损失公式(7); 修正参数μ,θ,更新参数β以最小化生成器损失公式(8);return β,μ;
3 实验与分析
3.1 数据集和实验方法
医疗文本数据采集一个临床科室的电子病历信息,处理生成结构化信息Medical Record保存在JSON文件中。使用三个公开文本分类数据集HuffPost、Amazon、Reuters 和Medical Record,随机抽样构建小样本文本,四个基准数据集如表1所示。
表1 四个基准数据集的统计数据
ADAML模型分为训练和测试两步,首先通过数据集构建不同的元任务,每个元任务包括一个支持集和一个查询集。 训练模型时将采样的元任务输入模型,通过支持集上的损失函数对模型进行微调,并通过查询集上的损失函数对模型参数进行更新。测试时通过元任务中的支持集对模型进行微调,最后在查询集上计算准确度。
3.2 实验与分析
文中使ADAML与MAML[5]、PROTO[18]、ATAML[6]、HATT[19]模型对比处理小样本文本分类问题的准确率。文本分类算法的性能通常采用准确率进行测评,定义如下:
(9)
实验过程中注意力生成器计算词汇注意力分数,元知识生成器使用具有 128 个隐藏单元的双向 LSTM生成元知识表征向量,在领域判别器中两个前馈层的隐藏单元数量分别设置为 256 和 128。元训练期间进行 100 次训练,当验证集上的准确度在 30 次迭代中没有显著变化时,停止此次训练。根据测试结果评估模型性能,见表2。
表2 模型在四个数据集上3way 1shot 和 3way 3shot分类的准确率
模型在四个数据集上均取得了比较好的分类效果,在1shot分类中的平均准确率为 68.85%,在3shot分类中的平均准确率为 82.1%,比模型 ATAML分别提高了3.98%和7.63%,ADAML模型在Medical Record上的 1shot 和 3shot 分类比其他模型平均提高了20% 和16.2%,Medical Record中文本的平均长度比其他长,实证表明ADAML模型更适合包含丰富语义信息的文本分类。不同模型在Medical Record数据集上的不同类别分类结果见表3。
表3 不同模型在Medical Record上不同分类的准确率
4 结 论
医疗文本蕴含丰富的语义信息,有效的分类可以促进医学技术的发展。文中提出一种基于注意力机制的元学习模型,根据词语的重要性赋予不同的权重,并且利用两个对抗性网络增强模型的学习能力,提高小样本文本分类的适应性与准确率。通过对比实验和分析证明了文中模型在公开数据集和医疗文本数据集上的有效性,后续将提高小样本数据集含有噪音的分类性能,并尽可能减少训练过程中的语义损失。