面向知识蒸馏的自动梯度混合方法①
2024-01-10曹炅宣张曦珊
曹炅宣 常 明 张 蕊③** 支 天** 张曦珊**
(*中国科学技术大学 合肥 230026)
(**中国科学院计算技术研究所 北京 100190)
(***中科寒武纪科技股份有限公司 北京 100191)
0 引言
近年来,得益于越来越深的网络层数和越来越大的参数量,深度神经网络(deep neural networks,DNN)在各类任务中取得了显著的成功。然而,DNN有着较高的计算复杂度和较大的参数储存要求,因此将其部署到运算资源有限的设备或者对即时性要求较高的应用场景变得比较困难,例如智能手机、嵌入式设备和边缘计算。因此压缩大的模型并提高其运行速度变得非常重要。知识蒸馏(distilling the knowledge,KD)就是一种十分有效的模型压缩方法,其通过从大型教师网络中提取有用的知识转移给小型学生网络从而提高小型网络的性能。知识蒸馏的损失函数包含2 个部分,一个是来自于真实标签的任务损失,另一个部分则是来自于教师网络的蒸馏损失。
因此,如何有效地找到2 个损失函数的权重成为了一个待解决的问题。换言之,如何在训练过程中更合理地混合这2 个损失的梯度。现在大多数已有的知识蒸馏方法都是手动调整损失权重,这种方法既繁琐又十分浪费计算资源,并且往往无法达到最佳性能。手动搜索权重的问题主要在于权重的搜索空间范围特别大,而且往往是连续的。例如,根据框架RepDistiller[1],在相同的数据集和师生网络组合下,蒸馏损失权重在0.02(基于概率的知识转移方法(probabilistic knowledge transfer,PKT)[2]) 到30 000(计算相关性的知识蒸馏方法(correlation congruence for knowledge,CC)[3])之间变化。
针对这个问题,可以采用超参数优化(hyperparameter optimization,HPO[4])和多任务学习(multitask learning,MTL)来确定2 个损失的权重,但将这2 种方法应用到知识蒸馏的训练时存在着一些缺陷。在知识蒸馏训练中,存在2 个优化目标:用于任务损失的真实标签以及用于蒸馏损失的教师网络。而知识蒸馏设计的初衷就是使用蒸馏损失作为辅助,帮助作为主要目标的任务损失降低到最小。但超参数优化和多任务学习方法会认为这2 个损失处于一个平等情况,因此会产生大量冗余的搜索空间导致参数调节过程的效率十分低下,并且过分平衡用于辅助的蒸馏损失也有一定可能损害到主优化目标(即任务损失),后续多任务学习的实验也证明了这一点。
为了解决上述问题,本文提出了一种新颖的自动梯度混合方法,该方法可以自动地为知识蒸馏训练找到合适的损失函数权重。本文将寻找合适的损失函数权重的问题转换为寻找2 个损失通过反向传播得到的最佳混合梯度的问题。考虑到在知识蒸馏中,蒸馏损失是任务损失的辅助这一重要的先验知识,自动梯度混合方法可以显著减少混合梯度的搜索空间。通过找到混合梯度的模长和方向从而确定用于更新模型参数的混合梯度。在具体训练过程中,混合梯度的模长用来控制模型参数更新速度,而方向则是决定着模型最终的训练结果。因此自动梯度混合方法通过固定混合梯度模长与任务损失产生的梯度模长相同,用来保证模型迭代的稳定性。在只需要搜索方向的情况下,可以有效地减少混合梯度的搜索空间并提高搜索效率。在确定了混合梯度的模长和方向后,就可以计算出2 个损失函数的权重,从而避免了复杂的手动调节过程。
与现有的手动调节方法相比,本文提出的自动梯度混合方法有效利用了知识蒸馏的先验知识,具有以下几个优点:首先,自动梯度混合方法将混合梯度的模长约束到与任务梯度模长相同,这样能够保证模型训练的收敛稳定性,解耦了梯度向量模长和方向,只需要在方向上进行搜索,显著减少了搜索空间;此外,在进行了该梯度模长的约束后,早期训练轮次的结果与最终训练轮次的结果具备一个较好的保序性,从而通过一个极短时间的预训练即可找到较优的混合方向,从而实现了比手动设置权重更好的性能;最后,自动梯度混合方法是一种简单易用的方法,能够适用于绝大部分的知识蒸馏方法,可以对某种蒸馏方法在某类应用场景下是否有效进行一个快速验证。
为了证明自动梯度混合方法的效果,本文在CIFAR-100[5]和ImageNet-1k[6]数据集上使用Rep-Disitiller[1]框架进行实验,自动梯度混合方法在130个组别中表现超过70%的手动调节结果。在时间上,与超参数优化方法相比,自动梯度混合方法只需要1/10 或者更少的时间就能达到与超参数优化方法相当的精度。
1 相关工作
1.1 知识蒸馏
知识蒸馏将大的、笨重的教师网络的知识转移给更小、更敏捷的学生网络中,从而能够有效提高学生网络的性能。Hinton 等人[7]提出了这种方法,该方法使用温度来修正教师网络输出的softmax,使其作为软标签来指导小型的学生网络。目前有3 种不同类型的知识蒸馏,分别是基于响应、基于特征和基于关系的知识蒸馏方法[8]。基于响应的方法[7]旨在通过使用教师网络的logits 作为知识来直接模拟教师网络的最终预测。基于特征的方法[9-12]则是专注于匹配教师网络和学生网络中间层的特征。基于关系的方法[1,3,13-14]认为不同层或数据样本间的关系能有助于蒸馏。然而,现有绝大部分方法都使用手动调整来找到合适的任务损失权重和蒸馏损失权重,这既繁琐又十分耗时,而且往往无法达到最佳性能。
1.2 超参数优化
超参数优化(HPO)方法是一类寻找最优的超参数组合的方法。这些方法可以分成3 类。第1 类是穷举搜索,例如随机搜索和网格搜索。网格搜索将超参数空间划分为不同的网格并运行每个网格对应的参数组合以此找到最佳参数。这种遍历式的搜索方法由于没有对搜索空间进行任何裁剪,因此非常耗时。为了使得搜索过程效率更高,研究人员提出了第2 类启发式搜索方法,该类搜索方法可以在搜索过程中根据可用信息(例如之前训练的结果)选择后续最佳的搜索分支。超参数优化方法中包含有一些经典的启发式搜索方法,例如朴素进化和模拟退火。最近,研究人员也提出了Hyperband[15]、Popluation-Based Training[16]等新的启发式方法。第3 类是贝叶斯优化,它通过条件概率建模来预测给定超参数的最终性能,例如序列贝叶斯优化(sequential Bayesian optimization hyperband,BOHB)[17]、树形Parzen 估计方法(tree-structured Parzen estimator approach,TPE)[18]等。与手动设置参数相比,超参数优化方法的调节器理论上可以节省一些搜索时间,但是仍然非常耗时。
1.3 多任务学习
多任务学习(multi-task learning,MTL)是指通过使用所有任务和其他一些任务中包含的知识来共同学习多个任务,以此来提高每个任务性能的一种训练方法。多任务学习方法包括2 个方面[19]。一些多任务学习方法设计深度学习多任务架构,包含有设计侧重于编码器[20-21]或侧重于解码器[22-23]的架构。其他的一些多任务学习方法则是侧重于平衡多个任务的训练优化,例如Uncertainly[24]、GradNorm[25]、DWA[26]、DTP[27]、Multi-Objective Optim[28]。绝大部分多任务学习方法都会等权重优化所有任务或者是所有损失函数,因此它可能会和知识蒸馏中将任务损失视为主要损失、将蒸馏损失视为辅助的理念相冲突。
2 自动梯度混合
2.1 问题定义
为了提高小型学生网络的性能,知识蒸馏除了利用来自于真实数据的监督外,还额外引入了来自于大型的教师网络中的有益的知识。因此,总的损失函数由来自于真实标签的任务损失和来自于教师网络的蒸馏损失构成,公式为
这里Lkd是总的知识蒸馏损失函数,Ltask是任务损失,Ldistill是蒸馏损失。α和β是任务损失和蒸馏损失的缩放系数。为了获得合适的系数α和β,绝大部分已有的知识蒸馏方法都是通过手动调节方法来进行搜索,这类方法非常繁琐又耗时,并且往往无法使学生网络拥有最佳的性能。为了解决这个问题,本文提出了一种自动梯度混合方法来自动高效地找到损失权重。
假设在整个训练过程中,第t轮的模型参数更新迭代时,损失函数对模型参数求导后得到的梯度被用来迭代模型参数,公式为
2.2 高效搜索混合梯度
为了有效搜索最优混合梯度,需要尽可能地缩小搜索空间。在这项工作中,本文利用了知识蒸馏中的一个重要的先验知识,即任务损失是主要优化目标,而蒸馏损失是任务损失的辅助。因此,混合梯度Gkd应当与任务梯度Gtask更加相关,蒸馏梯度Gdistill用来做一个细化调整。本文通过确定混合梯度Gkd的方向和模长来找到这个混合梯度。一般而言,在使用梯度来更新模型参数的过程中,梯度向量具有2 个自变量,一个是方向,另一个则是模长,两者的功能具有一定差异。梯度的模长主要影响着模型参数的更新速度,从而控制模型收敛,当模长太长时,会出现梯度爆炸使得模型无法收敛或者是在最优值附近徘徊的情况;而模长过短时,模型收敛会非常缓慢,找到最优值的时间过长,也有可能陷入到某个局部最优点中。梯度的方向则是决定着模型参数的更新方向,决定模型最终的收敛位置能否在相应的指标上取得好的效果(如分类任务中的准确率,检测任务中的mAP 等)。在非蒸馏训练中,模型仅使用任务损失产生的梯度就能训练出来一个稳定的结果。本文基于上述先验知识,为了提高效率减小搜索空间,以及保证模型训练的收敛稳定性,自动梯度混合方法将混合梯度的模长约束到与任务损失梯度模长相同,公式为
在实现该约束后,可以很方便地将学生网络的非蒸馏训练版本的超参数,如学习率、权重衰减等,方便应用到本文中使用的蒸馏训练上。因此可以通过对Gkd的模长约束得到一个稳定的训练过程。
在确定了模长大小后,自动梯度混合方法只需要在搜索空间中搜索Gkd梯度方向,该梯度方向由任务梯度Gtask和蒸馏梯度Gdistill决定。如图1 所示,Gkd方向的搜索空间为Gtask和Gdistill之间的角度空间。θ为Gtask和Gdistill夹角大小:
图1 梯度混合示意图
假设Gtask和Gkd的夹角为λθ,只需要在λ∈[0,1] 这个范围内进行搜索。在这种方式下,由于不需要对Gkd的模长进行搜索,整个搜索空间得到大幅度缩减,同时对最优方向的搜索可以保证混合梯度Gkd的有效性。
通过搜索得到λ后,可以用如下公式表示Gkd的方向:
使用式(3)和(4),可以得到:
联立式(5)~(7),可以解得损失权重系数α和β为
2.3 热身策略
如式(9)所示,损失权重系数α和β取决于λ。λ的有效值为[0,1]。当λ等于0 时,蒸馏损失对混合梯度没有任何影响;当λ等于1 的时候,混合梯度方向会完全遵循蒸馏梯度的方向。此外,实验结果表明自动梯度混合方法在训练早期和后期的性能(在分类任务中为准确率)有着良好的保序性。因此,为了进一步提高搜索过程中实验的效率,本文使用训练早期的训练效果来预测最终的性能。在具体操作中,本文在搜索空间中对λ进行一个早期的搜索来作为预热训练模型。然后选择性能最佳的一个作为λ的最佳值。之后可以采用式(9)来计算损失权重α和β,并且使用它们来完成训练。搜索和训练模型的整个过程如算法1 所示。
3 实验和结果
本节中,本文将提出的自动梯度混合方法应用在被广泛使用的图像分类数据集CIFAR-100[5]和ImagNet LSVRC 2012[6]上。此外,本文使用的Rep-Distiller[1]框架基于Pytorch,其模型库中包含有13种流行的蒸馏方法。在实验中,本文遵循RepDistiller 默认的超参数设置,如训练轮次、学习率、优化器等。在自动梯度混合方法中,预热轮次设置为5。作为对比实验,本文使用RepDistiller 中给出的手动调整的损失权重的训练结果作为基线。
3.1 CIFAR-100 上的实验结果
本文在KD[7]、Fitnets[11]、SP[29]、AT[12]、CC[3]、VID[29]、RKD[13]、PKT[3]、FT[10]和NST[9]这10 种蒸馏方法上进行实验。此外,实验还包含有7 个相似架构的师生网络组合和6 个不同架构的师生网络架构,即整个实验包含有10 ×13 个小的实验。
结果如表1 所示,可以发现自动梯度混合方法和手动方法比较,无论是在教师网络架构和学生网络架构相似的VGG13-VGG8 和ResNet110-ResNet32亦或者是ResNet32x4-ShuffleNetV2 和VGG13-MobileNetV2 这类架构差异很大的网络上都有比较好的效果。总结表1 的结果可以发现,自动梯度混合方法在70%的蒸馏组合上都要比手动调节的方法表现得更好。
表1 在数据集CIFAR-100 上使用手动调节(Manual)和 自动梯度混合方法(AGB)在10 种不同的蒸馏方法和13 种不同的师生网络组合的Top-1 准确率(%)
3.2 ImageNet-1K 上的实验结果
本文使用KD、CC、对比表示知识蒸馏方法(contrastive representation distillation,CRD)和注意知识蒸馏方法(attention on distillation,AT)在Image Net-1K数据集进行实验。因为RepDistiller 框架没有ImageNet-1K 对应代码,所以本文在ImageNet-1K 上复现了这4 种方法。超参数和手动调整的损失权重是按照另一个蒸馏框架TorchDistil 设置的。本文使用Pytorch 团队发布的模型ResNet34 和ResNet18 作为教师和学生网络,并遵循TorchDistill 的ImageNet 训练设置。
表2 展示了自动梯度混合方法和手动参数设置方法在以ResNet34 和ResNet18 作为师生网络组合上的top-1 准确度。对于KD、CC 和AT 方法,自适应梯度混合方法可以获得更好的性能,对于CRD 方法,自动梯度混合方法也可以达到和手动设置接近的性能。因此,ImageNet-1K 上的实验有效证明了自动梯度混合方法的有效性。
表2 自动梯度混合方法(AGB)和手动调整(Manual)在ImageNet-1k 上的Top-1 准确度(%),其中教师网络是ResNet34(top-1 准确度73.314%),学生网络是ResNet18(top-1 准确度69.76%)
3.3 和超参数优化方法的比较
本文在CIFAR100 上使用自动梯度混合方法和Microsoft Neural Network Intelligence (NNI)的3 个不同的超参数优化调节器进行了对比。这些超参数优化方法包括有启发式搜索方法模拟退火(simulated annealing)、Hyperband[15]和贝叶斯优化方法TPE[18]。选择VGG13 和VGG8 作为师生网络,并使用AT 蒸馏方法进行实验,在超参数优化方法中,参照式(1),设置α等于1,β的搜索空间为0.02 到30 000。
图2 显示了3 个超参数优化调节器和自动梯度混合方法的比较实验。可以观察到自动梯度混合方法只需要极少训练的时间就能达到非常高的精度。相比之下,在运行同样的时间中,超参数优化方法只能实现更低的精度。尽管超参数优化方法在最终的结果中达到了与自动梯度混合方法相当或者略高的精度,但它们需要更多的时间来进行搜索,这是非常低效的。
分析超参数优化方法出现的问题,可以发现无论是手动调节、超参数优化或者是一些简单约束情况,都会导致超参数搜索过程变得漫长而复杂。本质上,这是由于这类方法在搜索超参数时会将总梯度向量模长和方向进行耦合,同时去搜索梯度向量的方向和模长,会影响模型的收敛性,并出现两类冗余搜索的情况:(1)搜索到合适的方向而模长过长或过短,导致出现模型无法收敛;(2)搜索到合适的模长而方向不对,这样会影响模型最终的收敛位置,即影响模型最终的结果。而当一些更为奇怪的约束使得总梯度向量的方向与模长耦合得更加紧密时,甚至无法搜索到对应合适方向。
3.4 和多任务学习方法的比较
将Uncertainly 和GradNorm 这2 种无超参数的多任务学习方法与自动梯度混合方法进行对比实验。本文对所有的10 种蒸馏方法进行了实验,所有的13 种教师学生网络组合与3.1 节中的相同。
如表3 所示,自动梯度混合方法应用到绝大多数蒸馏方法中都优于这2 种多任务学习方法。多任务学习方法将蒸馏损失和任务损失平等对待,忽略了知识蒸馏的重要先验知识,即任务损失是起到主导作用的,而蒸馏损失是用于辅助的。因此,多任务学习方法可能会为了最大限度地降低蒸馏损失而牺牲了性能。还可以发现,当使用GradNorm 时,大多数蒸馏方法的性能都很差。这是因为GradNorm 完全忽略了任务损失应该为主导地位。而且,与任务损失相比,蒸馏损失通常非常大或者非常小。例如,在CC 中,蒸馏梯度的模长约为任务梯度的100 倍,
表3 多任务学习方法GradNorm 和Uncertainly 在CIFAR-100 上与自动梯度混合方法(AGB)相比的Top-1 测试准确度(%)。由于训练过程中的梯度爆炸,一些方法显示出非常差的准确性或无法训练出有效的结果(用表示)。null 表示此蒸馏方法不支持多任务学习方法。
而在PKT 中,蒸馏梯度的模长约为任务梯度的0.001倍。因此,GradNorm 简单地平衡2 个损失将会导致整个训练过程不稳定。相比之下,自动梯度混合方法将混合梯度的模长限制为与任务梯度的模长相同。因此,自动梯度混合方法在获得稳定训练过程的同时,可以保留任务梯度占据主导地位这一重要信息。
3.5 保序性证明
本文验证了在自动梯度混合方法中训练早期和训练后期准确率的保序性。在CIFAR-100 上使用AT 蒸馏方法进行这些实验,在NNI 上用VGG13 作为教师网络,VGG8 作为学生网络。计算早期(第5轮)的准确率和整个训练结束的最终准确率之间的相关系数。本文还对手动调节方法进行了这些实验,α设置为1,β从0.003 变化到30 000。为了公平地比较,本文选择结果接近收敛时的最后80 次实验来验证相关性。
如图3所示,下图为自动梯度混合方法,其相关系数为0.724,远高于上图中手动调节方法的0.410。这个实验说明了使用自动梯度混合方法时早期轮次表现较好的设置同样可以运用到晚期轮次。因此,预热策略可以在不损失性能的前提下大幅提升自动梯度混合方法的效率。
图3 最佳精度与早期轮次精度之间的相关性
3.6 消融实验
本文就预热阶段设置的热身轮次和预热阶段用于离散化的步长进行了消融实验。在CIFAR-100上使用KD 蒸馏方法进行实验,教师网络为Res-Net32x4,学生网络为ResNet8x4。
图4 显示了准确率、时间开销与步长的关系。可以看到,当步长从0.2 变小后,时间开销增大,对应的结果略有上升;而当步长变大后,实际上的节省的时间相当有限,而性能也会出现一定程度的下降。图5 则显示了准确率、时间和热身轮次的关系。可以发现,与前面步长类似,选取更小的热身轮次并不会导致运行时间有一个显著的变小。而当热身轮次提升后,时间开销增大了,对于实验的准确率也没有提升太多。因此本文取的热身轮次和步长并不具备特殊性,取附近的几个值结果差异不会太大,这也说明了是前面模长约束在方法中起到了主要的作用而预热的等间距选取最优的策略只是用于辅助的。
图4 准确率、时间与步长之间的关系
图5 准确率、时间与热身轮次之间的关系
3.7 自动梯度混合方法的高效性
图2 中的结果也显示了自动梯度混合方法的高效性。图2 中圆点表示一次超参数优化方法实验的准确性。随着训练实验的增加,每条虚线表示 超参数优化方法的最佳准确性。三角形标记表示自动梯度混合方法的结果,该方法需要大约1.50 次实验时间才能达到72.48%的准确率。在知识蒸馏中寻找损失权重时,手动调整会受到大的搜索空间的影响。通过使用贝叶斯优化或者是其他算法改进搜索过程,超参数优化方法会高效一些,但是仍然有着比较大的搜索空间。相比之下,自动梯度混合方法通过约束混合梯度的模长并仅仅在预热阶段在方向上进行搜索,从而显著减少了搜索空间。如图2 所示,超参数优化方法需要10 次以上的实验才能达到与自动梯度混合方法相当的精度。因此,与超参数优化方法相比,自动梯度混合方法效率更高。
4 结论
本文提出了一种自动梯度混合方法,可以有效地为绝大部分知识蒸馏方法找到合适的损失权重。利用蒸馏损失是用于辅助任务损失这一先验知识,自动梯度混合方法通过减少超参数搜索空间来优化搜索过程。自动梯度混合方法只搜索梯度方向,即2 个损失梯度之间的角度,同时将混合梯度的模长约束为与任务损失梯度模长相同。本文在13 种不同的师生网络组合之间对10 种不同的知识蒸馏方法进行了实验。自动梯度混合方法在使用更少的运算资源的前提下在70%的蒸馏方法上性能超过了手动调节方法,这说明自动梯度混合方法具有更好的效果以及更高的效率。本文工作的前提是假设当有多个蒸馏损失时,所有的蒸馏损失共享相同的权重。未来,可以将本文工作扩展到具有多种蒸馏损失的情况。