FCAT⁃FL:基于Non⁃IID 数据的高效联邦学习算法
2022-07-26陈飞扬张一迪
陈飞扬,周 晖,张一迪
(南通大学 信息科学与技术学院,江苏 南通 226019)
随着5G 技术发展,越来越多设备接入网络,将产生的大量数据直接上传至云端。 采集数据集中训练的云计算技术由于存在隐私泄漏、能量消耗大、带宽压力大、传输延时等问题,在移动边缘网络中逐渐式微[1-2]。 随着边缘设备计算能力不断提高,在设备本地直接训练模型成为可能。 为解决边缘设备协同训练问题,文献[3]首次提出联邦学习(Federated Learning, FL)概念,在分布式FL 框架下,边缘设备无需向服务器上传本地数据,利用自身数据就地训练,将训练完成的模型参数上传至服务器进行聚合,聚合参数再次下发给边缘设备继续训练,不断重复,直至效果良好。 然而,传统FL 算法(如联邦平均算法FedAvg) 在客户端数据非独立同分布(Non⁃independent and Identically Distributed, Non⁃IID)情况下效果不佳。 文献[4]发现当数据Non⁃IID 时,FedAvg 在CIFAR⁃10 数据集上的准确率下降55%,使用Earth's Movers' Distance(EMD)解释了下降原因,并通过创建全局共享数据子集缩小本地模型之间差距,提高Non⁃IID 数据的训练效果;文献[5-7]分别使用懒惰节点判别法、CSFedAvg 算法和深度强化算法仔细选择客户端以求减小Non⁃IID 数据对聚合的影响;文献[8]通过层次分析法衡量客户端Non⁃IID 数据的质量,依据数据质量为客户端分配权重;文献[9]提出可自适应调整的FedProx 算法,通过修改客户端本地损失函数,限制各客户端局部模型之间差距,提高网络模型在Non⁃IID 数据上的性能;文献[10]借鉴AdaBoost 算法的思路,根据上一轮全局迭代中客户端损失值动态调整当前迭代中每个客户端本地训练轮次,从而加快Non⁃IID 数据下FL 的收敛速度;文献[11]将FL 从同步层面扩展到异步层面,改变传统聚合方式,提出一种指数滑动加权平均方法FedAsync,服务器在接收到某个客户端发来的模型参数后,将其与全局模型参数进行加权后发送给该客户端,以求形成一个全局泛化能力更强的模型;文献[12]提出一种将模型减枝与FL相结合的方法,目的在于减小模型的尺寸,加快本地模型收敛速度;文献[13]认为Non⁃IID 数据下传统FL 可能会偏向某些设备,提出q⁃FFL 策略为各客户端更新本地模型参数,以求提高FL 性能;文献[14]引入α⁃fair 策略,在尽量保证聚合模型性能前提下提高客户端间准确率分布的公平性;文献[15]提出FedFa 聚合算法,在每轮全局迭代时以客户端训练准确率和被选中次数为依据,为各客户端分配聚合权重,从而提高Non⁃IID 数据下聚合模型的准确率和公平性。
为解决客户端数据Non⁃IID 带来的收敛速度慢、公平性差等问题,本文首先定量分析收敛速度慢的原因。 然后以分析结果为根据,按照客户端本地模型参数和服务器聚合模型参数之间关系,提出两种不同的聚合策略,服务器依据策略在每次全局迭代时为不同客户端分配可变的自适应权重以求减少全局迭代轮次,并且在客户端引入个性化迁移学习(Transfer Learning, TL)模型和动量梯度下降算法(Momentum Gradient Descent, MGD)以求提高客户端本地模型性能,从服务器和客户端两方面共同加快Non⁃IID 数据下FL 的收敛速度,减小客户端通信和计算开销。 最后,使用合适的评价标准说明FCAT⁃FL 对提高客户端间公平性和准确性也有良好效果。
1 问题描述
FL 网络架构如图1 所示,M个客户端的数据集为{C1,C2,…,CM}, 对应的数据量为{| C1|,|C2|,…,|CM |}。 在多分类问题下,数据样本{Χ,Y} 由特征空间Χ和标签空间Y={1,2,3,…,C} 构成,C为类别标号。 设每轮通信有N个客户端参与,则服务器聚合公式、客户端模型更新公式、FL优化目标[16]分别为
图1 联邦学习网络框架
式中,predi(x,wt) 为预测函数,表示样本x预测为第i类(i <C) 的概率,dk(y=i) 表示客户端k上第i类样本的数据分布。
假设第t次聚合后,服务器发送全局模型参数wt给各客户端,本地客户端m和n接收到服务器下发的模型参数后在本地更新,由式(2)、(4)可得两者间的模型参数差异为
式(5)揭示了客户端模型参数差异与客户端数据分布差异间的关系。 在FedAvg 中,当客户端数据IID 时,客户端数据分布差异很小,因此,客户端间模型参数差异也很小,即使分布训练后再根据数据量占比聚合模型,其全局模型也能逼近但略差于集中式训练的模型;但当数据Non⁃IID 时,若某客户端数据分布高度倾斜,客户端间模型参数可能偏离甚至相反,这导致有的本地模型对聚合产生积极影响,将之聚合有利于全局模型性能改善;有的则对聚合产生消极影响,将之聚合会损害全局模型性能,而且随着全局迭代轮次增加,不同客户端模型参数差异会逐渐累积。 这样的发现使得依据客户端数据量比重进行聚合的想法变得不再合理。 一个合理聚合方式应为:给有利客户端较大聚合权重,给有害客户端较小聚合权重。 这种聚合思想的提出为2.1 节中证明聚合策略对FL 收敛速度的影响及如何改进聚合策略提供了指导性建议。
2 基于Non⁃IID 数据的FCAT⁃FL 算法
在数据Non⁃IID 下,为提高FL 收敛速度,从服务器和客户端两方面入手。 在服务器端,结合第1 节的建议,改进聚合策略并从理论层面说明其对减少全局迭代轮次的有效性;在客户端,使用TL 模型和MGD算法,保证本地模型个性化的同时减少客户端需训练的模型参数数量,加快本地模型训练速度。 服务器和客户端两方面的共同改进加快了FL 收敛速度。
2.1 聚合策略的改进
前文分析出在客户端数据Non⁃IID 时,有害客户端影响聚合模型性能,导致需要更多轮全局迭代才能达到收敛,轮次增加意味着通信开销或成本增加,这在客户端资源有限的移动边缘网络中是不可取的。 本节从聚合层面定量分析影响FL 收敛速度的因素,依据前文的建议,为FL 的快速收敛提出一种有效的聚合方式。
图2 策略流程图
2.2 TL 和MGD
对于分布式FL 而言,减少通信开销比减少计算开销更重要。 为此,在客户端引入TL 模型和MGD 算法。 TL 的使用减少客户端需训练和上传的参数数量,减小通信开销,同时,在分布式计算中,客户端通常拥有较少的训练资源,随着模型复杂度的增加,使用较少数据训练的本地模型往往性能不佳,TL 的使用可以帮助客户端在数据较少的情况下更好地提取特征进行训练;MGD 的使用加快本地模型训练速度,减小计算开销。 从而,使FCAT⁃FL 更适用于用户资源有限的移动边缘网络。
TL 主要研究迁移任务A 上学习的知识至任务B,使任务B 用更少代价去训练模型并获得不错的性能[20]。 如图3 所示,对于相似的任务A 和B,如果它们的特征提取方法相近,则前面数个特征提取层(卷积层)可以重用,后面的分类子网络可以根据具体任务设定从零开始训练。 卷积层用于提取样本特征,靠前的卷积层提取低级特征,越靠后的卷积层的抽象提取能力越强,分类子网络由全连接层构成,能够提取样本个性特征,是TL 模型的个性化层。
图3 神经网络迁移学习
使用随机梯度下降算法(Stochastic Gradient Descent,SGD)更新客户端本地模型可能产生左右振荡的情况。 文献[21]证明了带动量的分布式随机梯度下降算法的快速收敛性。 因此,为加快客户端本地模型收敛速率,客户端TL 模型在更新时综合考虑当前损失函数梯度值和之前的梯度值,得到如式(11)所示的加权动量项,然后利用式(12)进行本地迭代以求加快本地模型收敛速度。
对于复杂图像的迁移学习分类问题,可以使用PyTorch 官方提供的基于ImagNet 数据集已预训练完成的VGG16、ResNet18、ResNet50 等模型进行特征提取迁移学习操作或微调迁移学习操作[22-23],并加以个性化的全连接层。 其中,特征提取迁移学习操作冻结预训练模型的特征提取层参数,仅训练全连接层,这样可以减少客户端需训练的模型参数数量,加快客户端训练速度,并且可以获得不错的准确率;微调迁移学习操作使用预训练网络初始化目标网络,而非随机初始化,由于重用部分网络已经学习到良好的参数状态,通过简单的微调,网络模型可以快速收敛到较好的性能。
3 公平性与准确性评价
联邦学习的最终目标是最小化全局损失函数,获得一个泛化性能最好的聚合模型,而以客户端数据量占比权衡贡献度的传统聚合策略在数据Non⁃IID 情况下会使全局模型偏向于某些客户端,使客户端间准确率分布变得不公平。 2.1 节提出的策略改变传统聚合方式,在每次全局迭代时依据客户端本地模型参数和服务器聚合模型参数间的关系,动态地为不同客户端分配自适应权重。 前文已证明该动态策略能够减少全局迭代轮次,是否对提高客户端间公平性和准确性也有积极影响?
为评价2.1 节的动态聚合策略对客户端间公平性和准确性的影响,参照文献[24-25],引入Jane's index 值评价客户端间准确率分布的公平性,但Jane's index 本质上是以方差的形式评价客户端间准确率分布是否公平,不能对准确率本身做出评价,为体现客户端准确率水平,在Jane's index 公式基础上做出改进,如式(13)所示。
4 算法流程
5 实验
5.1 数据分配
使用PyTorch 框架在MNIST 数据集上模拟FL 过程,该数据集有60 000 张训练样本和10 000 张测试样本,训练样本分配给客户端用于训练,测试样本留在服务器端,用于计算全局模型的准确率和损失值。 实验设置15 个客户端和1 个服务器,基于IID 数据和Non⁃IID 数据做FCAT⁃FL 和其他算法的对比实验,因此,需为每个客户端分配IID 数据和Non⁃IID 数据:将训练样本随机洗牌打乱,按图4 为客户端分配数据量相同的IID 数据;将训练样本以标签大小排序,按图5 为客户端分配数据量相同的Non⁃IID 数据。
图4 IID 数据分配
图5 Non⁃IID 数据分配
5.2 实验设置与结果分析
5.2.1 客户端实验
为说明TL 与MGD 在减少模型参数数量和加快客户端训练速度方面的有效性,模拟分布式框架下客户端拥有较少数据的真实情况,令某一客户端仅有600 个MNIST 训练数据和200 个测试数据,然后在该客户端上设计3 个对比实验。 实验1(TL⁃CNN):使用基于复杂数据集CIFAR⁃10 预训练的拥有两层特征提取层(卷积层)和一层全连接层的网络模型(该网络已事先训练好),保留特征提取层参数,并将其迁移至客户端上基于500 个MNIST 训练数据的相同模型中,追加合适的全连接层,冻结除全连接层外的所有网络参数,仅训练全连接层参数,同时,使用MGD 算法更新客户端本地模型。 实验2(CNN):客户端使用两层卷积层和一层全连接层网络模型进行训练,同时,使用SGD 算法更新客户端模型。 实验3(BP):在客户端构建三层BP 网络模型,使用SGD 算法更新本地模型。 3 个对比实验的准确率和损失函数如图6 所示,需训练的模型参数数量如表1 所示。
实验使用简单的数据集进行图像分类,因而即使神经网络结构很简单,客户端的训练数据量很少,也能获得如图6 所示的较好效果。 从图6 可看出,TL⁃CNN 模型的准确率和损失值要好于CNN 和BP模型;从表1 可知,TL⁃CNN 需训练的模型参数数量比CNN 少10.98%,比BP 少29.54%。 由TL⁃CNN和CNN 的对比可以看出,在网络结构相同和客户端数据量较少的情况下,客户端使用TL 模型和MGD算法可以提高模型性能,加快本地模型训练速度,减少需训练及上传的模型参数数量。
表1 模型参数数量表
图6 对比实验效果图
本节实验对比3 种模型最终的参数数量,没有考虑TL 前期预训练的成本,但是本节实验是为验证在FL 框架下运用TL 的好处。 在实际应用中,针对复杂图像分类问题,可以使用PyTorch 官方基于ImageNet 数据集已预训练完成的VGG16、Resnet50等网络或相似任务已训练完成的网络模型,拷贝其特征提取层参数并加上合适的全连接层,进行特征提取迁移学习或微调迁移学习,而无需进行预训练,从而有效解决前期预训练成本问题。
5.2.2 服务器聚合实验
在客户端使用TL 和MGD 的前提下,当所有客户端的数据为IID 和部分客户端的数据为IID 时(每个客户端有600 训练样本,200 测试样本),用2.1 节的聚合策略1、策略2 与平均策略[3]、FedAsync 策略[11]、q⁃FFL 策略[13]、FedFa 策略[15]这4 种基线聚合策略进行对比,同时,为证明实验结果的有效性,使用AUC 值展示分类模型的优劣。
每轮聚合后服务器在10 000 张测试样本上计算当前聚合模型准确率、损失值,对比结果分别如图7、8、9、10 所示,分类模型训练完成后的AUC 值如表2 所示。 由表2 可见,6 种策略的AUC 值随着客户端数据Non⁃IID 程度的增加而减少,但AUC 值均在0.96 以上,证明6 种策略的分类模型效果优良。从图10 可看出,当所有客户端的数据为IID 时,6 种策略聚合模型的准确率、损失值、收敛速度基本相同;从图7、8、9 可看出,当部分客户端的数据Non⁃IID 时,策略1 和策略2 的优势逐渐体现。 其中,当1/3 或1/2 客户端的数据不满足IID 时,策略1 和策略2 聚合模型的最终准确率和损失值略微好于4 种基线算法,但相差不大,全局收敛速度则快于4 种基线算法;当3/4 客户端的数据不满足IID 时,FedFa策略聚合模型前期效果较差,策略2 聚合模型的准确率和损失值后期出现波动,而策略1 聚合模型的准确率、损失值、收敛速度一直好于4 种基线算法,且后期也很稳定。 总体看来,6 种策略的对比结果说明本文提出的策略1 在部分客户端的数据Non⁃IID 时能有效提高FL 全局收敛速度且鲁棒性更好,证明根据客户端模型参数与全局模型参数间关系权衡贡献度的想法可行且效果优良。
表2 6 种策略不同情况下的AUC 值
图7 3/4 客户端的数据为Non⁃IID
图8 1/2 客户端的数据为Non⁃IID
图9 1/3 客户端的数据为Non⁃IID
图10 所有客户端的数据为IID
5.2.3 公平性和准确性实验
为验证2.1 节提出的策略1 能提高客户端间公平性、准确性,每轮全局聚合完成的模型下发给本地客户端计算各客户端测试样本准确率Acci,并根据式(13)计算J(Acc) 值。 在客户端数据不同的Non⁃IID 程度下,6 种策略的J(Acc) 值如图11 所示。 从图11 可以看出,随着客户端数据Non⁃IID 程度增加,不同策略的效果差异开始体现。 其中,当3/4 客户端的数据Non⁃IID 时,其余5 种策略对提高客户端间公平性和准确性的效果皆比平均策略要好,且策略1、FedAsync 策略、q⁃FFL 策略三者效果更胜一筹;策略2 的J(Acc) 曲线后期波动明显,鲁棒性较差;FedFa 策略虽然训练前期效果不好,但后期效果能与策略1、FedAsync 策略、q⁃FFL 策略相媲美。 实验证明本文提出的策略1 对提高客户端间公平性与准确性有良好效果且鲁棒性较好。
图11 0,1/3,1/2,3/4 客户端的数据Non⁃IID
6 结束语
在FCAT⁃FL 中,服务器依据局部模型参数和聚合模型参数之间的关系,使用两种不同的策略量化每个客户端的贡献度后进行全局聚合,从理论和实验两方面证明本文所提的两种聚合策略相比其他基线聚合策略,能有效减少客户端数据Non⁃IID 下的全局迭代轮次,且策略1 的鲁棒性更好。 在客户端使用TL 模型和MGD 算法,帮助拥有较少训练数据的客户端更好地训练本地模型,并且减少客户端需训练和上传的模型参数数量,加快客户端本地模型训练速度。 客户端和服务器两方面的共同改进加快FL 收敛速度,减少客户端通信和计算开销,提高FL性能。 使用改进的Jane's index 证明了相比于其他基线聚合策略,本文提出的策略1 能有效提高数据Non⁃IID 下客户端间公平性与准确性。 可见,FCAT⁃FL 能有效提高客户端数据Non⁃IID 下FL 的收敛速度、公平性、准确性,为在客户端资源有限和数据Non⁃IID 的移动边缘网络中执行FL 提供一种有效的解决方案。