基于任务相似度的增量学习优化方法
2021-04-22张甜郭辉郭静纯
张甜 郭辉 郭静纯
摘要:针对增量学习存在的灾难性遗忘和新任务数据逐步积累问题,提出了基于新旧任务之间相似度的样本重放优化学习方法,相似度越高,重放样本越少。并选择MINIST数据集在卷积神经网络上进行了实验研究,验证了该方法的可行性和有效性。
关键词: 增量学习;灾难性遗忘;样本重放;任务相似度
中图分类号: TP183 文献标识码:A
文章编号:1009-3044(2021)08-0013-03
Abstract: To solve the problem of catastrophic forgetting and gradual accumulation of new task data in incremental learning, an optimal learning approach based on the similarity difference between old and new tasks is proposed. The more similar the tasks are, the less the old samples will be replayed. Moreover, MINIST data set is selected to conduct experimental research on the convolutional neural network, which verifies the feasibility and effectiveness of the method.
Key words:incremental learning; catastrophic forgetting; sample replay; task similarity
随着深度学习的快速发展和在图像、语音等领域的应用,其在单个任务处理方面取得了优异的性能。但当它面对多任务增量学习时,常常产生“灾难性遗忘”现象[1],即学习新任务时会改变原有的网络参数,相应的旧任务记忆就会急剧下降甚至完全消失。
样本重放是缓解灾难性遗忘的主要方法之一,包括两种典型方式:一种通过旧任务的伪样本生成器保留其信息,如深层生成重放[2]和记忆重放GANs[3],不使用旧任务原始数据,但GAN模型训练较复杂;另一种直接选用旧任务的原始数据子集,如内存固定的iCaRl[4]及其改进训练样本不均衡的增量学习文献[5],文献[6]提出一种自动记忆框架,基于样本参数化选取具有代表性的旧样本子集,采用双层优化训练框架。这些方法均未考虑新旧任务之间的相似度差异:相似度越高,网络提取的共有信息越多,则对旧任务的回顾应越少。此外,真实环境下新任务数据通常按照时间顺序流式到达,新数据较少,无法满足上述方法的需要。针对这些问题,本文提出了一种基于任务相似度的增量学习优化方法,根据两者之间相似度差异设置不同比例的训练数据,避免重复训练,减少资源占用,加快训练速度。
1 样本重放增量学习优化方法
增量学习优化方法的实现过程主要分为以下三个阶段:首先,当新任务到达时,用特征提取器提取新旧类特征,进行相似度差异分析;其次,根据相似度差异结果,计算新旧任务不同比例的训练数据增量,构建每批次增量训练数据集;最后,进行增量优化训练,实现符合真实场景下的新任务数据增量训练和任务增量学习。
1.1 符号表示
假设增量学习分为1个初始阶段和N个新任务的增量阶段。在初始阶段使用数据[D0]进行训练得到网络模型[Θ0];在第[i]个增量阶段,若有[s]个旧类[X1,X2,...,Xs],新类[Xi,i∈N],模型状态为[Θi-1],令[Di?j]、[Dij]、[Dj]分别表示第[i]类第[j]个批次的新增样本数据、前[j]个批次新数据和第[j]个批次的新旧训练数据。
1.2 任务相似度分析
根据假设,新任务数据流式到达。当新任务到达时,首先,选取同等数量的旧任务样本和首次到达的新任务样本作为代表性样本一起训练特征提取网络作为特征提取器[φ],通过使用新旧任务的平衡数据集,特征提取器可以更均衡地提取新旧任务的样本特征,使网络能充分学习新旧任务样本之间的差异,得到更具有代表性的样本特征。
对新任务样本数据提取特征后,采用余弦相似度衡量新旧任务之间的相似程度,其值越大,特征越相似,计算公式如下:
1.3 构建增量训练数据集
由于相似度较高的两个任务,在进行网络训练时,相同部分特征已经被提取到了,所以对于相似度较高的任务,新旧任务越相似,则越应减少旧任务的重放训练样本数量,减少重复训练造成的资源浪费;反之,则应增加旧类的数量,强化旧类知识,减少网络对于新类的偏向。根据新旧任务之间的相似度,令每批次重放旧任务的样本增量为[Doldk?j],其计算公式如下:
1.4 蒸馏损失和分类损失计算
蒸馏损失最早在文献[7]中提出,在增量学习中适用于文献[4,6,8],主要用来促使新的模型和旧的模型在旧类上保持相同的预测能力。增量学习损失包括蒸馏损失[LdΘi;Θi-1;x]和衡量分类准确度的交叉熵损失[LcΘi;x]之和,两者的计算公式分别如下:
1.5 增量优化训练
通过分析不同任务之间的相似性差异,在新任务数据流式到达时设置不同比例的新旧数据进行增量优化训练,整个的训练流程总结如下:
算法1 增量优化训练
输入 1个初始任务(2个类别的分类任务)的数据集[D0],N个新增任务(一个类别表示一个任务)的流式数据集[Di,i∈N]
输出 N+1个任务(N+2个类别)的分类性能
(1) 用数据[D0]训练得到网络模型[Θ0]
(2) 新任務到达,[Di1=500],[Di?j=500],有s个旧类(s的初始值为2)
(3) 新旧类之间进行相似度差异分析,用公式(1)计算新类与每个旧类的余弦相似度[sφXold,φXnew]
(4) 根据相似度差异结果,用公式(2)计算旧类每批次投放的样本增量[Doldk?j,k∈s]
(5) 用公式(3)构建第j个批次的训练数据
(6) 进行增量训练
(7) if各个类别的分类性能达到预期 //测试网络分类性能
(8) then if 还有未完成的任务 then 返回步骤(2) //继续训练下一个增量任务
(9) else 输出N+1个任务(N+2个类别)的分类性能 //已经完成N+1个任务的增量学习
(10) end if
(11) else then 返回步骤(5) //任务分类准确率没有达到要求,继续训练
(12) end if
2 实验研究
选取MNIST数据集中的数字0、1、2在三层卷积神经网络上进行增量学习,以数字0和1作为初始阶段,数字2为新增类别阶段。实验结果如表1所示。
由表1可知,采用本文方法进行增量学习,在第6批次时的平均准确率为0.9818,比重放全部旧数据的准确率0.99稍小,但训练数据量急剧下降,由5923+6741个旧样本变为60+66,显著提升了训练效率。以此类推依次完成数字3-9的增量学习,对比结果如图1所示。
图1中折线图的横坐标为增量学习的各个阶段,纵坐标为平均分类精度,图中结果表明相较于使用全部的新旧类训练数据,使用新的基于任务相似度的增量学习优化方法虽然在分类精度上有所下降,但是结果相差不大,能有效缓解灾难性遗忘的影响,且所使用的训练数据集要远小于使用全部的训练集,减少了训练量,加快了训练速度。
3 结论
针对增量学习中的灾难性遗忘问题,提出了一种基于新旧任务相似度的样本重放学习方法,在尽量保持对旧任务记忆的同时着力提升学习效率,据此选用MINIST数据集进行实验研究,验证了该方法的可行性与有效性,为缓解灾难性遗忘提供了新的解决思路。
参考文献:
[1] McCloskey M,Cohen N J.Catastrophic interference in connectionist networks:the sequential learning problem[J].Psychology of Learning and Motivation,1989,24:109-165.
[2] Shin H, Lee J K, Kim J, et al. Continual learning with deep generative replay[C]. Advances in Neural Information Processing Systems. Curran Associates: New York, 2017:2991-3000.
[3] Wu C S, Herranz L, Liu X L, et al. Memory Replay GANs: learning to generate images from new categories without forgetting[C].Advances in Neural Information Processing Systems. Curran Associates: New York, 2018: 5962-5972.
[4] Rebuffi S A, Kolesnikov A, Sperl G, et al. iCaRL: Incremental Classifier and Representation Learning[C]. Proc of the IEEE Conf on Computer Vision and Pattern Recognition. Piscataway: IEEE Computer Society, 2017: 5533-5542.
[5] Castro F M, Marin-Jimenez M J, Guil N, et al. End-to-End Incremental Learning[C]. European Conference on Computer Vision. Berlin: Springer, 2018:233-248.
[6] Liu Y Y, Su Y , Liu A A , et al. Mnemonics Training: Multi-Class Incremental Learning Without Forgetting[C]. CVPR, 2020:12242-12251.
[7] Hinton G, Vinyals O, Dean J. Distilling the Knowledge in a Neural Network[J]. Computer Science, 2015, 14(7)38-39.
[8] Zenke F, Poole B, Ganguli S. Continual Learning Through Synaptic Intelligence[C].International Conference on Machine Lea rning. Lille: International Machine Learning Society, 2017:3987-3995.
【通联编辑:唐一东】