APP下载

基于隐层相关联算子的知识蒸馏方法

2022-09-26吴豪杰王妍洁蔡文炳林绍辉

关键词:中间层算子准确率

吴豪杰 ,王妍洁 ,蔡文炳 ,王 飞 ,刘 洋 ,蒲 鹏 ,林绍辉

(1.中国电子科技集团公司第二十七研究所,郑州 450047;2.北京跟踪与通信技术研究所,北京 100094;3.中国人民解放军63726 部队,银川 750004;4.华东师范大学 计算机科学与技术学院,上海 200062;5.华东师范大学 数据科学与工程学院,上海 200062)

0 引 言

近年来,随着深度学习与图形处理器(Graphics Processing Unit,GPU)硬件的不断发展,卷积神经网络(Convolutional Neural Networks,CNNs)已经在诸多人工智能领域取得了显著的成效,如区块链[1]、图像分类[2]、目标检测[3]等.得益于其大规模的数据量与强大的特征提取能力,CNNs 在某些任务上甚至已经超过了人类识别的准确率[4].同时,GPU 硬件的高速发展大大提高了网络模型的计算效率.

随着网络模型性能的提升,其计算开销与存储量也在不断增加.如AlexNet[2]模型,其具有0.61 亿网络参数和7.29 亿次浮点计算量(Floating-point Operations per Second,FLOPs),占用约240 MB 的存储空间.对于被广为使用的152 层残差网络(Residual Network-152,ResNet-152)[4]具有0.57 亿网络参数和113 亿次浮点计算量,占用约230 MB 的存储空间.庞大的网络参数意味着更大的内存占用,而巨大的浮点计算量意味着高昂的训练代价与较小的推理速度.这使得如此高存储、高功耗模型无法直接在资源有限的应用场景下应用,如手机、无人机、机器人等边缘嵌入式设备.因此,在保持模型识别准确率的前提下,对于网络模型进行压缩与加速,以适应边缘设备的实际要求,成为了当前计算机视觉领域火热的研究课题.与此同时,也有研究表明[5],在巨大的网络参数内部,并不是所有的结构和参数对于网络的识别预测能力都起到决定性作用,这使得模型压缩技术,即移除冗余性参数和计算量成为了一种有效的解决方案.

当前主流的模型压缩方法可以分为5 种,分别为参数剪枝、参数量化、低秩分解、轻量型网络结构设计和知识蒸馏(Knowledge Distillation,KD).知识蒸馏方法可以直接设定压缩后模型的结构、计算量和参数量,以及不引入额外的计算算子,这使得知识蒸馏技术得到了广泛关注.因此,本文也着重研究基于知识蒸馏的模型压缩方法.知识蒸馏方法将较大和较小的网络分别定义为教师网络和学生网络 (也称之为压缩后网络).其主要思想在于,通过最小化该两个网络输出分布差异,来实现网络间的知识迁移,使得学生网络尽可能地获得教师网络的知识,提高学生网络的准确率.从而,学生网络可以在维持其参数量不变的情况下提升性能,尽可能逼近甚至有可能超越教师网络的性能.传统的知识蒸馏方法是将网络的输出分布作为知识在网络间进行迁移,随着该研究领域的进一步发展,研究发现[6],利用其他一些具有代表性的表征信息或知识在网络间进行迁移或蒸馏,可以获得比传统知识蒸馏方法更好的效果.知识蒸馏方法大致又可以分为: ①基于网络输出层的知识蒸馏方法;② 基于网络中间层的知识蒸馏方法;③基于样本关系之间的知识蒸馏方法.

本文提出了一种新的基于隐层相关联算子的知识蒸馏 (Correlation Operation Based Knowledge Distillation,CorrKD) 方法,通过计算教师网络与学生网络各自隐含层之间的关联性,挖掘出更有效的知识表征,从而将教师的知识表征迁移到学生的知识表征中,提高学生网络的判别性.该方法的核心是利用了被广泛应用于光流[7-8]、图像匹配[9]等领域内的相关联算子,用于提取网络中间层的知识表征.相关联算子的特性在于,可以很好地表征两个特征之间的匹配程度,并反映其特征的变化过程.首先,本文对于网络中每个阶段的输入特征与输出特征,利用相关联算子进行建模与知识提取,有效获得了图像特征的学习变化信息.然后,将教师网络每阶段通过相关联算子提取出的表征信息作为知识,迁移到学生网络中,提升学生网络判别性和学习有效性.

在CIFAR-10 和CIFAR-100 分类数据集评测结果中,相比其他中间层知识蒸馏方法,本文所提出的方法取得了较好的效果.同时,本文所提出的方法在减小网络的计算量和参数量的同时,能够有效逼近原始网络的准确率.

1 相关工作

1.1 主流模型压缩方法

除本文将详细介绍的知识蒸馏方法外,其他主流的模型压缩方法有: ①参数剪枝[10-11],该方法的主要思想在于,通过对已训练好的深度神经网络模型移除冗余、信息量较少的权值,减少网络模型的参数,进而增大模型的计算速度和减小模型所占用的存储空间,实现模型压缩;② 参数量化[12-14],该方法的主要思想是一种将多个参数实现共享的直接表示形式,其核心思想在于,利用较低的位来代替原始32 位的浮点型参数,从而缩减网络存储和浮点计算次数;③低秩分解[15-16],该方法的核心思想在于,利用矩阵或张量的分解技术对网络模型中的原始卷积核进行分解.一般来说,卷积计算是网络中复杂度最高且最为普遍的计算操作,通过对张量进行分解从而减小模型内部冗余性,实现模型压缩;④ 轻量型网络结构设计,轻量型网络结构设计的方法主要是改变了卷积神经网络的结构特征,提出了一些新颖的轻量计算模块或操作,从而精简网络结构,增大处理速度.如基于深度可分离卷积的MobileNet[17],利用神经网络结构搜索得到的EfficientNet[18]等.

1.2 知识蒸馏方法

知识蒸馏方法[19]指利用教师网络中的知识表征为学生网络提供指导,以提高学生网络的性能.传统的知识蒸馏方法通过最小化教师网络和学生网络类别输出分布的KL (Kullback-Leibler)散度来实现蒸馏.除了在输出层外,网络中间层的特征信息也被应用到知识蒸馏方法中.

中间层特征知识的构造.Romero 等[20]提出的FitNet 是较早利用中间特征信息进行知识蒸馏的方法,其目标是使经过奇异值分解的学生网络尽可能学习教师网络中间层的特征信息.随后,Zagoruyko 等[21]提出在网络中间层引入注意力机制,将每层的注意力特征作为可学习的知识迁移到学生网络中.近年来,随着自注意力模型被广泛运用到变形器[22]中,进而获得人工智能领域各项任务的性能突破,相关知识蒸馏方法[23-24]通过对齐教师与学生的自注意力矩阵实现知识迁移.Yim 等[25]提出了FSP (Flow of Solution Procedure)方法,将网络中每层之间的数据流动关系作为知识,由教师网络迁移到学生网络中.除此之外,样本之间的关系特征也被发现可以凝炼出更好的知识表示.例如,Park 等[26]提出RKD (Relational Knowledge Distillation)知识蒸馏框架,对于不同样本网络输出的结构关系进行建模,将关系特征进行知识迁移.此外,Liu 等[27]通过将教师网络特征空间映射到由顶点与边构成的图表示空间中,然后对齐教师与学生网络的顶点以及它们边的对应信息实现知识蒸馏.Tung 等[28]利用网络中间层每个样本之间的相似度信息进行知识迁移.Kim 等[29]提出在教师网络的最后一层特征中提取便于学生网络理解的转移因子,将知识传递给学生网络.对于教师网络和学生网络中间层特征不一致的情况,Heo 等[30]提出了使用 1×1 卷积进行维度对齐,并构建教师网络激活边界作为中间层知识迁移到学生网络中.不仅如此,特征图的雅可比梯度信息[31]也可以作为中间层特征知识表示.近年来,出现了一些在输出层特征进行对比学习[32]或基于自监督[33]的知识蒸馏方法,分别用于挖掘教师网络和学生网络对于不同样本之间的关系,从而将教师网络的关系知识迁移到学生网络中.不同于以上知识蒸馏方法,本文所提出的基于相关联系数的知识蒸馏方法作用于每阶段中间层特征信息,从而获得每阶段中间特征变化信息,能更好构建知识表征,提高学生网络的学习性能.

使用优化训练策略进行中间层知识蒸馏.近年来,大量生成对抗思想被应用到中间层知识蒸馏中,提高知识蒸馏性能.例如,Su 等[34]引入了任务驱动的注意力机制,将教师网络和学生网络各自高层信息嵌入低层中,实现中间层信息的迁移,同时加入判别器用于增强学生网络最后输出特征的鲁棒性.类似地,Shen 等[35]提出了基于对抗学习的多教师网络集成蒸馏框架,利用自适应池化操作对齐一个学生与多个教师集成网络的中间层输出维度,同时利用生成对抗策略对池化的中间层特征进行对抗训练,提高了知识蒸馏性能.Chung 等[36]提出了基于中间层特征图的在线对抗蒸馏框架,设计教师网络和学生网络的判别器,用于共同学习和对齐这两个网络在训练过程中的特征图分布的变化情况.Jin 等[37]提出了一种路线限制优化策略,预先设定好教师网络训练的中间模型状态,并通过逐步对齐学生网络与其中间层特征分布,使得学生网络获得更好的局部最优解.

2 方 法

2.1 知识蒸馏方法

知识蒸馏方法[19]认为在数据的网络输出中,每一个数据的预测概率结果都可以看作是一个分布,不仅关注于置信度最高的类别所对应的结果,而且对于预测错误结果的置信度概率也具备一定的网络知识.在传统分类任务所使用的交叉熵损失函数中,只会关注对应于正确类别的概率值,对于其他类别所对应的概率是直接丢弃,没有利用的,Hinton 等[19]将其称作是暗知识.在知识蒸馏的过程中,学生网络所学习到的,不仅是预测正确的类别所对应的概率值结果,而且包括教师网络所学习到的暗知识.

在具体的实现过程中,将教师网络记为ft,学生网络记为fs,将输入记作x,教师网络和学生网络的模型输出结果分别记为zt和zs,且zt=ft(x),zs=fs(x),zt,zs∈Rd,d为总类别数.对于网络得到的输出分布,利用 Softmax对此进行归一化,得到概率分布.同时,还引入了温度分布参数τ用来平滑该层的输出分布,以强化网络输出的概率分布中所学习到的知识,通过温度平滑后的网络输出被称为软目标.对此,以教师网络为例,对于第i个输入样本xi,其软目标用公式表示为

式(2)中:n表示样本总个数,KL(ps||pt) 定义为学生网络输出分布与教师网络输出分布之间差异,具体公式表示为

所以,在学生网络训练的过程中,教师网络的软目标与真实标签共同起到监督作用.传统知识蒸馏损失函数为

式(3)中:LCE为传统的学生网络输出与真实标签的交叉熵损失函数;α为平衡因子,用于权衡LCE和LKL的重要性比例.

2.2 相关联算子

相关联算子[7]被广泛应用到光流、图像匹配、目标跟踪领域中,用于描述两张图像或两个特征之间的匹配程度(图1).对于三维的图像特征张量A和B,其尺寸为C×H ×W,C、H和W分别表示其特征图的通道数、高度与宽度.特征张量A中给定位置 (i,j)的特征为PA(i,j)∈RC,需要计算其与特征张量B中所对应位置图像块的特征相似度,这里所对应的图像块以 (i,j)为中心,大小为k×k,将该区域内的像素位置记为 (i′,j′) ,所对应的特征为PB(i′,j′) ,与PA(i,j)类似,该像素特征均为C维向量.因此,可以通过计算内积的方式得到对应像素特征之间的相似度,由此得到相关联算子φ,其计算公式为

图1 相关联算子示意图Fig.1 Illustration of correlation operation

式(4)中:⊙表示向量内积,为归一化系数.由此,可以得到特征张量A和B之间的相关联算子,可以将其记为φ(A,B)∈Rk2×H×W.所以,对于给定的两个三维图像特征张量,可以通过计算像素特征与图像块中每个像素之间的相似度,得到尺寸为k2×H ×W的相关联算子,用于反映特征之间的相似程度或匹配程度.

2.3 基于隐层相关联算子的知识蒸馏方法

借助相关联算子,可以计算网络模型隐层中尺度相同的两个特征张量之间的特征,用以反映特征的匹配相似程度,并利用其进行知识迁移 (图2).图2 中的KL 损失LKL和LCor损失分别被定义于式(2)和式(5)中,xi和分别为第i个输入样本和该样本增强变化后的表示.

图2 基于隐层相关联算子蒸馏方法的整体框架Fig.2 Illustration of intermediate CorrKD framework

通常,网络模型会根据其特征图空间尺寸大小的不同而划分成不同的阶段,换句话说,在相同的网络阶段内,其中间特征的维度尺寸都是相同的.因此,可以将每个阶段的第一层特征与最后一层输出特征作为相关联算子中的特征张量A和B.该相关联算子的计算可以很好地反映出模型每个阶段对于数据的处理变化过程,成为非常有效的知识表征.因此,可以将相关联算子计算结果用作知识蒸馏的表征信息,由教师网络对学生网络进行指导.假设网络有N个阶段,教师网络和学生网络的第i个阶段的第一层输入特征分别记为,最后一层的输出特征分别记为Fit2和Fis2,其知识迁移的过程可以利用LCor损失进行约束,对此,基于隐层相关联算子的知识迁移损失函数可以表示为

式(5)中:λi,i=1,2,···,N表示第i阶段的权重因子,||·||2为L2范数.为了更好形成多样的知识表征,在本文中引入数据增强和变化[4](如旋转、翻转、颜色变化等),可以更有效地将隐含层的相关联算子的知识迁移到学生网络中,从而产生更好的效果.通过结合了教师网络中传统知识蒸馏损失函数 (式 (3))和隐层相关联算子的知识迁移损失函数 (式 (5)),可以得到该知识蒸馏方法完整的训练损失函数公式为

式(6)中:β为超参数,用于控制3 个损失 (LCE、LKL和LCor) 的平衡性.在训练过程中,本文直接使用梯度下降法优化式 (6),选择学生网络进行测试,并计算出学生网络的准确率作为该方法的评测效果.

3 实 验

3.1 实验数据

本文在两个经典的分类公开数据集CIFAR-10 与CIFAR-100 上进行了实验,均包含6 万张长宽尺寸均为32 的图像,其中5 万张用于训练,剩下的1 万张用于测试,他们的分类类别数分别为10和100.

3.2 实验设置

本文所提出的方法使用Pytorch 在单张GPU 上进行实现,对于两种数据集均采用随机梯度下降方法进行优化.在训练中,图像批量大小设置为64,学习率设置为0.05,动量设置为0.9,权重衰减系数为0.000 5.对于教师网络,利用标准交叉熵损失函数进行训练,训练迭代次数为240,其学习率分别在第150、180、210 次迭代时,分别缩小为原来的1/10,训练完成后将教师网络进行保存,存储于本地磁盘中.

对于学生网络,需要先读取教师网络的模型参数,利用所提出的损失函数式 (6) 进行训练,模型训练优化器与学习率设置均与教师网络一致,训练迭代次数设为300,其学习率分别在第180,220,260 次迭代时,分别缩小为原来的1/10.

在相关联算子的计算过程中,需要引入数据增强,首先,对于图像进行随机旋转与翻转.其次,在图像色彩上从灰度转化、色彩抖动、高斯模糊等操作中随机选取一种对图像进行色彩上的增强.在相关联算子的计算过程中,参数k=7,对于所选取的网络模型,其结构均为4 个阶段,也就是式 (5) 中的N=4,同时将每个阶段的权重设为相等,也就是λi=1 .设置式 (1) 中的τ=4 .最后,设置式 (6)中的α=0.2,β=5 .

3.3 实验结果

本文所提出的方法在多种模型结构上进行实验验证,选取ResNet[4]与WideResNet[38](WRN)作为网络主干,并在多种教师网络与学生网络组合上进行实验.表1 总结了4 组教师网络与学生网络的参数量与计算量信息.在表2 和表3 中,总结了本文所提出方法的性能效果,其中本文所提出的基于隐层相关联算子的知识蒸馏方法记为CorrKD,仅利用中间层式 (5) 与交叉熵损失训练得到的学生网络方法简称为Corr,KD 表示仅利用式 (3) 进行训练的传统知识蒸馏训练结果.注意到表2 与表3 中的第3 和第4 列分别表示教师网络与学生网络在正常情况下训练得到的基准准确率结果 (即只使用交叉熵损失函数).KD 展示了学生网络在利用式 (3) 训练得到的传统知识蒸馏方法的结果.

表1 实验所用模型参数量与计算量信息Tab.1 Model parameters and FLOPs information used in the experiment

从实验结果来看,单纯基于中间隐层相关联算子的知识迁移方法可以对于学生网络的训练带来一定的促进作用,但效果并不明显.通过结合了输出层的传统知识蒸馏方法KD 之后,在学生网络的分类正确率上,获得了很好的性能提升.在蒸馏教师网络WRN40-2 时,在CIFAR-10 上学生网络WRN16-2 的网络参数和网络计算量都约为原来教师网络WRN40-2 的31.8%,即参数量 (教师网络参数量为2.2 M,学生网络参数量为0.7 M,教师网络计算量为329.0 M,学生网络计算量为 101.6 M).如表2 所示,由本文所提出的CorrKD 方法得到的学生网络准确率只下降了0.5 百分点 (教师网络准确率为95.2%,学生网络使用CorrKD 方法准确率为94.7%).对于类别个数更多的CIFAR-100 上,同样蒸馏的网络选择,由本文所提出的CorrKD 方法压缩WRN40-2 后的网络计算量和参数量约是压缩前的31.8% (表1),准确率只下降1 百分点 (表3 中教师网络准确率为76.8%,由CorrKD 方法得到的准确率为75.8%).由此可见,本文所提出的方法在准确率有限下降的情况下,模型能够获得显著的压缩比,压缩后形成的学生网络能够有效嵌入受限移动设备端中.

表2 CorrKD 在CIFAR-10 上实验结果Tab.2 Experimental results of CorrKD on CIFAR-10

表3 CorrKD 在CIFAR-100 上实验结果Tab.3 Experimental results of CorrKD on CIFAR-100

在CIFAR-100 上,也可视化了本文所提出的CorrKD 方法对于蒸馏WRN16-2 的训练损失的变化以及测试准确率的变化.如图3 所示,随着训练的回合数的增加,完整训练损失Lo逐步减小,同时测试准确率逐渐提升.该训练结果验证了本文所提出的方法在训练上的稳定性与有效性.

图3 完整训练损失Lo 和测试准确率变化曲线Fig.3 Curves of overall training loss Loa nd test accuracy with respect to the epoch number

在CIFAR-100 评测数据集上并以WideResNet 为主干网络,将本文所提出的方法与其他经典基于中间层的知识蒸馏方法进行对比,包括FitNet[20],AT (Attention Transfer)[21],SP (Similarity-Preserving)[28]和FT (Factor Transfer)[29].为保证公平性,上述中间层蒸馏方法都展示与传统KD 相结合训练的实验结果,各方法所得到的结果对比如表4 所示.从实验结果来看,本文所提出的知识蒸馏方法在WideResNet 模型结构上,和其他中间层的知识蒸馏方法相比,取得了较好水平.例如,在学生网络为WRN16-1 时,本文所提出的方法和AT 方法相比,准确率提高了0.1 百分点 (CorrKD 准确率为74.6%,AT 准确率为74.5%),同时,与教师网络WRN40-2 相比,准确率降低2.2 百分点 (CorrKD 准确率为74.6%,WRN40-2 准确率为76.8% (表3)).

表4 CorrKD 与其他知识蒸馏方法在CIFAR-100 上准确率对比Tab.4 Accuracy comparison between different KD methods and CorrKD on CIFAR-100

3.4 参数敏感性分析实验

本节主要探索部分超参数对于实验效果的影响,主要包括相关联算子中参数k的影响以及完整的训练损失函数中参数α,β的影响.实验均在CIFAR-100 上进行,教师网络结构选取WRN40-2,学生网络结构选取WRN16-2.对于3 组参数的实验结果分别如表5 和表6 所示,“教师网络→学生网络”表示教师网络蒸馏学生网络所使用的网络模型.在k相关的实验中,固定α=0.2,β=5 ;同理,在α与β相关的实验中,固定其他两个参数.从实验结果看出,实验中所选取的参数k=7,α=0.2,β=5 均为最佳参数.

表5 相关联算子参数 k 实验结果对比Tab.5 Comparison with different values of k in the correlation operation

表6 完整训练损失 Lo 中参数 α ,β 实验结果对比Tab.6 Comparison with different values of α ,β in the overall training lossLo

4 总结

本文提出了一种新的基于隐层相关联算子的知识蒸馏方法,首次将用于光流中的相关联算子计算操作运用到模型中间隐含层的特征提取中,相关联算子可以对特征之间的匹配程度或变化过程进行有效建模,反映模型中间层的表征信息.同时在数据增强的作用下,进行中间层的知识迁移,结合输出层的传统知识蒸馏方法,构成了本文所提出的全新知识蒸馏框架.实验表明,本文所提出的知识蒸馏方法在两种公开数据集上均取得了优越性能,并在WideResNet 模型上取得了同类型中间层知识蒸馏方法中的最优水平.在未来的研究中,可以考虑将该模型中间层表征知识提取方法利用到更多视觉领域下游任务的蒸馏中,并在多个任务上验证本文所提出方法的压缩效果.

猜你喜欢

中间层算子准确率
与由分数阶Laplace算子生成的热半群相关的微分变换算子的有界性
Zn-15Al-xLa中间层对镁/钢接触反应钎焊接头性能影响
乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析
不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨
2015—2017 年宁夏各天气预报参考产品质量检验分析
颈椎病患者使用X线平片和CT影像诊断的临床准确率比照观察
Domestication or Foreignization:A Cultural Choice
QK空间上的叠加算子
如何利用合作教学激发“中间层”的活跃
浅谈通信综合营帐数据中间层设计原则与处理流程