APP下载

基于注意力增强元学习网络的个性化联邦学习方法

2024-01-12高雨佳王鹏飞马华东

计算机研究与发展 2024年1期
关键词:联邦客户端准确率

高雨佳 王鹏飞 刘 亮 马华东

1 (北京邮电大学计算机学院(国家示范性软件学院) 北京 100876)

2 (智能通信软件与多媒体北京市重点实验室(北京邮电大学) 北京 100876)

3 (北京邮电大学人工智能学院 北京 100876)

(gaoyujia@bupt.edu.cn)

联邦学习[1]作为一种分布式机器学习框架,客户端可以在不向服务器传输数据的情况下进行全局模型训练,它解决了数据分散和数据隐私的问题. 与单独工作的个体相比,联邦学习通过客户端和服务器的协作训练,获得了更好的机器学习性能. 同时,可控制的传输间隔降低了通信负载,对于时间敏感的任务,如医疗监测[2]和流量预测[3],联邦学习可以显著减少训练延迟. 这种方法可以在具有相似数据特征和分布的客户群体中很好地工作. 但是,在许多应用程序场景中,客户端的数据集在分布、数量和概念[4]上可能不同,因此有一些限制. 这使得粗糙的全局协作在不考虑单个客户端数据属性的情况下无法获得良好的性能. 因此,在全局协同训练的基础上,为每个客户并行训练个性化模型就变得十分必要.

这种将个性化融入到联邦学习的方式被称为个性化联邦学习[5]. 常用的方法是取一个统一初始值(如全局模型),在本地训练过程中对其进行微调,以实现对客户端数据特征的适应性[6-7]. 由于通常情况下,一个模型不能适用于所有客户端,这种方式缺乏灵活性,难以解决客户的异质性问题. 同时,全局模型的粗融合也会影响全局模型的性能. 为此,我们希望探索一种更优的方式 ,为每个客户端提供不同的协作关系. 具体来讲,促进具有相似特征和数据分布的客户端之间的协作,并疏远具有广泛不同特征的客户端之间的协作. 因此,对于给定的客户端,如何自动提取与其类似的客户端,是一个具有挑战的问题. 此外,对于服务器-客户端框架,需要设计合理的协作训练方法.

为了解决以上问题,本文提出了一种个性化联邦学习框架FedAMN,如图1 所示. 和传统联邦学习架构一样,FedAMN 是由1 个服务器和多个客户端组成的. 不同在于,FedAMN 在每个客户端均分布部署了一个注意力增强元学习网络(attention-enhanced meta-learning network,AMN),来实现客户端间协作关系的生成.AMN 以客户端的模型参数作为输入特征,学习一个元模型来分析不同客户端的相似程度,它可以自动筛选对目标客户端有益的模型参数信息,从而提升客户端个性化模型的性能. 这种方法利用客户端的本地数据进行元模型和本地个性化模型的训练,解决了元模型缺乏监督信号而难以训练的问题. 并且因为数据不需要上传至服务器,客户端的数据隐私可以得到保护.

Fig.1 Framework illustration of FedAMN图1 FedAMN 框架示意图

具体来讲,给定一个目标客户端,AMN 通过注意力机制,利用其他客户端的本地模型参数分析之间的相似性,训练元模型. 用元模型计算的注意力分数构建一个新的注意力增强模型,实现客户端个性与共享的平衡. 通过这样的设计,可以增强训练的灵活性,满足对客户任务精准决策的需求. 同时,为了减少AMN 带来的额外的计算量增加,我们设计了一种渐进式的训练模式. 在开始的通信轮次中,仅对客户端的本地模型进行训练,当全局模型准确率达到稳定后,AMN 再加入训练. 在原本的客户模型基础上,进一步提升个性化模型的效果. 并且,考虑到客户端需要同时训练AMN 和客户端的本地基础网络,提出了一种交替训练的策略,以端到端的方式对2个网络进行训练,最终实现客户端的个性化注意力增强模型的生成.

为了证明FedAMN 的有效性,我们在2 个基准数据集和8 种基准方法上进行了大量实验. 与现有最先进的个性化联邦学习方法相比,我们的方法分别提高了至少3.39%和2.45%的模型性能. 此外,为了证明AMN 中注意力机制的有效性,我们对AMN 的双层网络结构进行了消融实验,并在真实世界数据集上,对AMN 各层的注意力分数进行了可视化分析实验.最后,设计了不同数据分布和数据特征的客户端中的对比实验,探究客户端异质性对训练的影响.

总的来说,本文的贡献有3 方面:

1) 将个性化联邦学习中的协作关系生成问题形式化为一个分布式元学习网络训练任务,并设计了一种注意力增强元学习网络(AMN)来解决客户端的本地模型的个性化问题;

2) 提出了FedAMN 框架,采用渐进式的训练模式,缓解AMN 引入带来的计算量增加问题,以交替训练的策略实现客户端个性化注意力增强模型的生成;

3) 对不同推理任务进行大量实验,证明我们的方法显著优于现有最先进的个性化联邦学习基线方法,并可以应用于许多具有隐私性的分布式数据个性化建模场景.

1 相关工作

本文的工作受到个性化联邦学习和元学习的思想启发,下面我们对这2 个领域的现有工作进行简要回顾.

1.1 个性化联邦学习

一般来说,现有的个性化联邦学习方法[5]可以分为3 大类,分别是本地微调方法、模型正则化方法和多任务学习方法.

在本地微调方法中,每个客户端接收一个全局模型,并使用其本地数据和多个梯度下降步骤对其进行调优. 例如,PMF[8]为用户本地模型的较高层设计了2 个个人适配器(个人偏差、个人过滤器),可以用客户信息对全局模型进行微调.FedPer[9]提出了一种基础+个性化层的方法,只对基础层进行协同训练.在模型正则化方法中,文献[10]在本地模型和全局模型的距离上增加了一个正则项,并使用一个混合参数来控制它们之间的加权比例. 文献[11]提出了一种知识精馏的方法来实现个性化,将正则化应用于本地模型和全局模型之间的预测. 这些本地微调方法的局限性在于,它们使用统一的全局模型进行个性化设置,不能为异构数据客户端上的大量潜在任务提供灵活的个性化建模. 文献[12]提出了一种分层异构联邦学习方法将全局模型拆分为适配不同客户端资源的子模型序列,提升建模效果. 联邦多任务学习方法[13]将每个客户机的优化视为一个新任务,它解决了通信约束、离散和容错问题,这些问题主要集中在凸模型上. 但是由于该方法对强对偶性的刚性要求,当客户采用非凸深度学习模型时,该方法就不再保证适用. 文献[14]提出使用注意力函数来测量模型参数之间的差异,它在服务器上为每个客户端维护了一个个性化云模型,通过传递个性化云模型实现信息的传递. 尽管注意力函数对客户端之间的成对协作进行了建模,但它是使用模型参数作为评估标准,难以对客户端之间的动态协作关系进行识别,在当客户端模型异构不稳定时,难以对个性化云模型进行准确地融合.

1.2 元学习

元学习是近年来研究的热点. 由于它能够在有限的训练数据下很好地泛化,因此被广泛应用于少样本学习[15]、强化学习[16]、迁移学习[17]等领域. 例如,文献[18]引入了一种记忆增强神经网络,用于在没有灾难性干扰的情况下利用新数据重新学习模型.文献[19]使用了一个元损失函数来存储不同环境的信息. 文献[20]提出了一个元翻译模型,该模型可以快速适应具有有限训练样本的新领域翻译任务. 最近,元学习也被用于个性化联邦学习. 文献[21-22]从经验的角度研究了模型不可知论元学习(MAML)类型方法与联邦学习的不同组合. 找到一个初始的全局模型,使当前或新的客户端能够通过执行1 个或几个步骤的梯度下降轻松地适应本地数据集. 文献[23]把元学习看作是对一系列损失的在线学习,每个损失函数都会得到单个任务中的上界. 文献[21-23]所提的方法可以被认为属于1.1 节中提到的本地微调方法.

与现有的工作不同,在本文中我们设计了AMN,它以细粒度的视角,利用基于注意力机制的神经网络来学习客户端之间的协作关系,并利用客户端的本地数据作为监督信号来训练网络. 该方法可以为客户端实现自适应的协作关系生成,并且被证明在客户端数据是异构的情况下特别有效.

2 问题定义

为了方便阐述,首先给出问题定义. 在一个个性化联邦学习系统中,有n个客户端连接到服务器. 每个客户端具有相同类型的基础网络f(·),客户端的模型参数集合W={w1,w2,…,wn}表示,其中wi∈R1×d为客户端i的模型参数. 服务器收集客户端上传的模型参数,通过聚合算法维护一个全局模型wg∈R1×d. 通过传输wg和W实现客户端的协同训练. 对于每个客户端,我们定义目标函数为F(wi)=L(f(wi);Di),其中Di表示客户端i的本地数据集,L(·)是损失函数.

由于客户端的数据集是异构的,我们的目标是在不公开客户端数据集的情况下,分析客户端本地模型参数之间的相似性,并构建细粒度的协作关系,以提升f(wi)的性能. 本文提出的AMN 可以被定义为GΘ(wg,W),其中 Θ为元模型参数. 此时,个性化联邦学习的优化目标定义为

其中 λ为正则化系数. 优化目标可以利用元网络和客户端的基础网络共同调节wi,考虑到大多数深度学习网络的复杂性,很难得到W和 Θ的最优解. 我们采用梯度下降技术和交替训练策略来解决这个双层优化问题.

3 基于AMN 的个性化联邦学习方法设计

3.1 FedAMN 框架概述

图1 展示了FedAMN 的整体框架设计,它由1 个服务器和多个客户端组成. 客户端的数据由客户自己收集,通常为边缘传感器收集,具有客户端的个性化特征. 为了保护客户隐私,服务器与客户端仅通过上传本地模型参数wi和下载全局模型参数wg以及参数集合W实现协同训练,并不会对本地数据进行传输.

服务器的主要功能为:

1) 收集客户端的模型参数,构成客户模型参数集合W;

2) 使用聚类算法(如k-means[24]或mean-shift[25]算法)将参数集合划分为k个簇;

3)根据全局聚合算法计算全局模型参数wg.

在功能2)中,聚类算法的意义在于降低引入A MN 而带来的额外数据传输量,同时最大限度地减少模型参数传输造成的潜在隐私泄露风险. 该操作对客户进行了粗分类,为之后元学习网络的自适应的协作关系选择提供基础. 在实际操作中,k值由网络传输量限制要求决定,理想情况为k=1(即全部客户端在同一个簇中,不对模型参数集进行划分). 客户端只会下载包含自己模型参数的簇Wk⊆W. 当任务对网络传输量的限制较高时,可以通过增大k的取值降低单个通信轮次中的传输量.k的取值越大,每个客户端接收到的模型参数簇就越小,单个通信轮次的传输量就越小.k值的选择标准为选取传输限制内可以实现训练的最小值. 为了简化描述,后文我们均以k=1 为例对FedAMN 进行说明,并以W表示客户端接收到的模型参数簇.

在功能3)中,全局聚合算法是指现有的联邦学习模型聚合算法,如FedAvg[1]及一些变体. 通过引入全局模型,AMN 可以引入全局信息作为个性化模型的补充,并为客户端的模型训练提供一个软启动.

客户端的主要功能为:

1) 收集客户数据,并对数据进行预处理,构建本地数据集;

2) 从服务器下载客户端的模型参数集合W(或聚类后的客户模型参数集合子集Wk)和全局模型参数wg;

3) 利用本地数据集训练客户端的基础网络;

4) 交替训练基础网络和元网络;

5) 上传本地模型参数.

由于AMN 的训练会引入额外的计算量,为了快速启动,我们采用渐进式的训练策略. 在训练初期,仅使用全局模型实现客户端间的协同训练,当客户端的本地模型达到稳定或达到设置好的训练阈值p时,AMN 再加入训练. 训练阈值p的设定是为了防止本地模型因数据量较少等原因导致不收敛时,依然可以引入AMN 的另一个限制条件.p的取值可以根据任务实际情况设定,设置为本地模型得到充分训练的最小通信轮次.

下面我们对AMN 的具体结构进行描述,再介绍FedAMN 的训练流程.

3.2 元学习网络结构

我们发现,现有的联邦学习方法通常利用全局模型作为信息传输的载体,但在客户端数据异构的情况下,单一的全局模型难以解决客户的异质性问题. 同时,模型的粗融合也会影响全局模型的性能.而在模型训练过程中,具有相似特征和数据分布的客户端通常会相互提供更多的有益信息. 因此,我们期望提出一个元模型,通过简单的训练就可以自适应地发现相似客户端间的协作关系,以促进模型训练. 同时减少具有不同特征的其他客户端对模型训练的负面影响. 基于以上需求,我们设计了一种分布部署在每个客户端中的元学习网络,实现客户端间协作关系的自动生成. 该网络被称为AMN.

AMN 以客户端的基础网络模型参数作为输入特征,代表客户数据特性. 并根据参数集合学习一个元模型来分析不同客户端的相似性. 为了实现模型参数的自动融合,受到文献[26]的启发,我们使用注意力机制实现客户端模型权重的自动分配. 该机制的基本思想是查询q和键值矩阵K之间的相关性,可以被表示为

其中d是比例因子,避免公式內积的值过大.

具体来讲,AMN 的网络结构如图2 所示,第1 层以客户端模型参数集合W为键值,客户端的本地模型wi为查询,利用注意力机制自动分析目标客户端与其他客户端的相似性,并使用加权分数对W进行聚合. 第1 层可以过滤掉对目标客户端没有帮助的其他客户端的模型参数. 第2 层以第1 层的输出、全局模型为键值,以wi为查询,从而对它们进行融合,实现客户个性和共性的平衡.

Fig.2 Structure illustration of AMN network图2 AMN 网络结构示意图

因此,AMN 的优化过程可以表示为

3.3 FedAMN 训练流程

在FedAMN 中,我们引入额外的元学习网络AMN 进行个性化联邦学习,因此在每个通信轮次中,客户端需要共同训练基础网络f(·)和元学习网络G(·). 这2 个网络之间具有复杂的依赖关系,具体来讲,元学习网络的输出会成为基础网络的模型参数,基础网络的参数又会作为元学习网络的输入. 为了更好地训练这2 个网络,FedAMN 采用渐进式的训练模式,缓解AMN 引入带来的计算量增加问题. 以交替训练的策略实现客户端个性化注意力增强模型的生成.

FedAMN 需要客户端和服务器配合完成协作训练,算法1 详细说明了在服务器进行的模型的收集与聚合过程. 算法2 描述了在客户端进行的基础网络和元学习网络的交替训练策略. 客户端和服务器通过传输模型参数实现信息交换. 下面,我们对客户端中进行的渐进式交替训练模式进行介绍.

算法1.模型的收集与聚合(服务器).

输入:客户端的本地模型参数集合W={w1,w2,…,wn},聚类参数k;

输出:全局模型参数wg.

① 初始化wg;

② 发送wg至各客户端;

③ fort=1, 2, … do

④ 等待接收客户端的模型参数集合

W={w1,w2,…,wn};

⑤ 对W进行聚类,得到k个簇

W1,W2,…,Wk;

⑥ 根据全局模型聚合算法计算全局模型参数wg;

⑦ 将Wj∈{W1,W2,…,Wk}和wg发送给对应客户端;

⑧ if 达到停止训练条件

⑨ break;

⑩ end if

⑪t=t+1;

⑫ end for

⑬ returnwg.

算法2.基础网络和元学习网络的交替训练策略(客户端).

输入:全局模型参数wg,对应的模型参数集合簇Wj,渐进式训练阈值p;

输出:客户端本地模型参数wi.

①初始化wi和 Θi;

② fort=1, 2, … do

③ 从服务器下载Wj和wg;

④ forepoch=1, 2, … do

⑤ ift<pthen

⑯ end for

⑰ end if

⑱ if 达到停止训练条件

⑲ break;

⑳ end if

㉑t=t+1;

㉒ 上传wi至服务器;

㉓ end for

㉔ end for

㉕ returnwi.

在开始的通信轮次中,客户端只负责训练基础网络. 该方式与通常的联邦学习训练方式相同,在这里不做专门描述. 当本地模型达到稳定或设定的阈值后,AMN 与基础网络交替训练.图3 为客户端i中AMN 的训练方式. 首先,在AMN 网络训练阶段,为了调整其可学习参数 Θi,我们需要将其输出嵌入到基础网络中,作为基础网络的模型参数. 固定该参数,利用客户端的本地数据集计算梯度,并回传梯度对AMN 参数进行调整. 该过程可以表示为

Fig.3 Training method of meta-learning network AMN in client i图3 客户端i 中元学习网络AMN 的训练方式

算法2 中用Wj表示模型参数集合簇. 如3.1 节中服务器的功能2)的描述,在后文我们统一用W表示该概念,即k=1 时的状态. 之后,利用训练好的GΘi(wg,W)得到注意力增强模型参数watt. 以watt作为基础网络的初始值,对基础网络进行训练,得到最终的客户端的个性化模型参数wi,该过程可以表示为

此时的wi即为该通信轮次中的客户端的个性化模型.

上文的描述中仅对单个客户端的AMN 训练方式进行了说明. 若考虑全部客户端,我们可以得到FedAMN 的全局优化目标为

为了对H(W,Θ)进行优化,以一个通信轮次为例,FedAMN 中客户端的训练流程为:

1) 从服务器下载W和wi.

2) 当通信轮次小于训练阈值p时,用F(wi)训练客户端的基础网络,以获得稳定的本地模型,梯度下降过程表示为

最后,对更新后的本地模型进行微调,得到最终的个性化模型

3) 当通信轮次大于训练阈值后,使用AMN 网络的输出更新客户端本地模型

并对更新后的本地模型进行微调,得到最终的个性化模型

当通信轮次大于设定值或客户模型收敛时循环停止,此时客户端的本地模型作为最终的客户个性化模型. 该过程与算法2 中的描述相对应.

在FedAMN 中,客户端承担了更多的计算量,服务器仅承担模型的聚合、存储、传输等功能. 这种情况下,FedAMN 允许客户在不泄露自己数据的情况下获取更多有益信息,减少了由于本地数据量小而导致的模型过拟合问题,并最大程度地消除了客户数据异质性带来的模型负迁移问题. 因此,每个客户端的模型准确率可以明显高于传统联邦学习方法以及单机训练方法.

4 实验结果与分析

4.1 数据集介绍及实验设置

为了验证FedAMN 的普适性和有效性,本文实验数据集采用联邦学习中普遍采用的基准数据集VMNIST 和真实世界中的空气质量数据集W&A China.

V-MNIST 数据集由6 种常用的手写字体识别数据集组成,分别是MNIST[27]、MNIST 加旋转、MNIST加噪声背景、MNIST 加图像背景、MNIST 加旋转和图像背景、Fashion MNIST[28]. 每个数据集都是10 分类任务. 为了模拟联邦学习场景下客户端的个性化数据,我们将这6 个数据集采用独立同分布的随机抽样方法,每个数据集按照原始数据量等比分布分为3 个子集. 6 个数据集中,共72 795 个样本被划分为18 个子数据集,每个数据集由同一个数据的子集组成,用于模拟18 个客户端的本地数据集.

W&A China 数据集由中国国家环境监测中心和国家气象中心发布①http://www.cnemc.cn,https://www.ncdc.noaa.gov,包含2017 年1 月1 日至2017 年12 月31 日在全国4 个直辖市(北京、天津、上海、重庆)中的42 个监测站点按小时采集的共23 万条环境数据. 每条数据包含13 种特征元素,分别为温度、压力、湿度、风向、风速、站号、收集时间和6 种主要空气质量污染物的浓度. 其中根据中国环境空气质量标准(AAQS)[29],可以按照PM 2.5 的浓度将空气质量分为5 个等级. 我们采用的任务为使用过去48 h的历史数据对未来的空气质量进行预测. 由于城市和周边环境的差异性,每个监测站点的数据分布是异构的,参考真实世界的设置,在实验中,每个站点都被认为是独立,用于模拟42 个客户端的本地数据集.

不同的基础网络会显著影响FedAMN 以及基线方法的性能,我们对2 个数据集采用了2 种不同的经典基础网络. 对于V-MNIST,我们用AlexNet[30]进行图像分类任务建模,AlexNet 是由5 个卷积层和3 个全连接层组成的卷积神经网络. 优化器为SGD,学习率设置为0.02,使用多标签交叉熵损失函数. 每个通信轮次中的epoch设置为1,批样本数量设置为32.对于W&A China 数据集,我们构建了一个2 层的门控循环单元网络(GRU)[31],每层128 个隐藏单元作为基础网络,用于时间序列推理任务. 将结果通过全连通层输出,并与真值进行比较,得到预测精度. 优化器为SGD,初始学习率设置为0.01,使用多标签交叉熵损失函数. 每个通信轮次中的epoch设置为2,批样本数量设置为32.

本文采用分类准确率和分类精确率作为对比各种方法的评价指标. 由于本文研究的是客户端模型的建模效果,因此使用客户端的本地测试集进行测试,最终结果为全部参与训练的客户端的测试结果的平均值. 对于联邦学习和个性化联邦学习方法,我们分别对V-MNIST 和W&A China 数据集分别设置最大通信轮次为150 和175.在该轮次内,选取在验证集中表现最佳的个性化模型或全局模型,用于结果的测试.

全部实验均在PyTorch1.5.0 下进行. 运行在一个4 Tesla-P100 卡GPU 服务器中,搭载Ubuntu 16.04.7 操作系统,CPU 为Intel®Xeon®E5-2620,128 GB 内存. 为了模拟实际情况下远程客户端的传输开销,采用基于socket 包的TCP 协议进行面向连接的可靠传输.

4.2 基线方法的对比实验

在本节中,我们将FedAMN 与8 种基线方法进行比较,分别是客户端单机训练、FedSGD、FedAvg[1]、FedProx[32]、 FedPer[9]、 FedHealth[33]、 联邦迁移学习(FTL)[5]和FedAMP[14]. 其中FedSGD,FedAvg,FedProx是联邦学习方法,FedPer,FedHealth,FTL,FedAMP 是个性化联邦学习方法.

表1 显示了不同基线方法与FedAMN 在相同设置下的分类准确率. 可以看出,当客户端之间没有协作关系时,单机训练可以达到一个基础的分类准确率.2 个数据集代表了不同客户差异性下的任务性能.在V-MNIST 数据集中,基线联邦学习方法除了FedSGD 之外,FedAvg 和FedProx 均带来了分类准确率的提升. 这是由于图像识别任务比较复杂,有限的本地数据不足以提供必要的信息,此时客户间的协作训练非常重要. 而在W&A China 数据集中,3 种联邦学习方法的效果均弱于单机训练. 因为W&A China 是真实世界数据集,各城市间环境差异性大,此时使用联邦学习方法反而会造成性能下降. 并且根据实验结果我们发现,在2 个数据集中,联邦学习方法的性能都显著低于个性化联邦学习. 与FedAMN相比,FedAvg 的平均分类准确率在V-MNIST 和W&A China 数据集中分别低了2.9 个百分点和11.59个百分点. 这是由于客户端数据分布差异性导致的,说明了在联邦学习训练中构建客户个性化模型的必要性.

Table 1 Classification Accuracy of Different Methods表1 不同方法的分类准确率 %

在个性化联邦学习方法中,FedPer,FedHealth,FTL属于本地微调方法,FedAMP 属于多任务学习方法.FedAMP 在基线方法中的性能表现最佳. 相对于FedAMP,FedAMN 在2 个数据集的分类准确率分别提升了2.72 个百分点和2.04 个百分点. 可以看出,与现有方法相比,引入AMN 可以显著降低数据分布差异较大对模型建立带来的负面影响.

表2 显示了不同基线方法在相同设置下的分类精确率. 分类精确率描述了在预测为正的样本中有多少是真正的正样本,体现了模型的查准率. 在VMNIST 数据集中,分类精确率和分类准确率表现相似. 在W&A China 数据集中联邦学习方法的分类精确率出现了大幅度下降. 这是因为V-MNIST 是基准数据集,不同类别样本的数量几乎分布均匀,而W&A China 是真实世界数据集. 通常,在类别不平衡分类问题中,模型的分类精确率更低. 城市中空气质量差的样本明显少于空气质量好的样本,在此种情况下,对于样本较少的类别(如5 级污染)查准率会更低. 尽管如此,FedAMN 相较于其他基线方法依然表现出更好的分类效果. 与最优的联邦学习方法FedProx 相比,FedAMN 的分类精确率在2 个数据集中分别提升了6.49 个百分点和17.56 个百分点. 与最优的个性化联邦学习方法FedAMP 相比,FedAMN 的分类精确率分别提升了2.45 个百分点和1.55 个百分点.

Table 2 Classification Precision Rate of Different Methods表2 不同方法的分类精确率 %

此外,为了验证FedAMN 在数据异构和本地模型异构情况下的训练表现,我们对W&A China 数据集的数据和本地模型结构进行了调整,对比了不同方法的分类准确率.

在数据异构实验中,原本的W&A China 数据集的每条样本数据均包含13 种特征元素,我们对数据集中的每个站点均随机去除其中的2 种特征. 此时,站点的数据特征元素为11 种,各站点的数据特征元素的种类并不相同,以此模拟数据异构的情况. 实验结果如表3 所示,可以看出,在数据异构情况下,联邦学习方法FedSGD,FedAvg,FedProx 相较于单机训练的分类准确率分别下降了20.96 个百分点、14.78个百分点、13.34 个百分点. 这是由于数据异构情况下,全局模型并不能很好地实现信息传递,导致模型效果变差. 个性化联邦学习方法FedPer 和FedHealth的分类准确率也略低于单机训练. 此时,联邦设置并不能带来模型效果上的提升. 个性化联邦学习方法FTL 达到了几乎等同于单机训练的效果. 在基线方法中只有FedAMP 可以达到略高于单机训练的效果.而FedAMN 相较于单机训练,分类准确率提升了4.36 个百分点,相较于表现最优的个性化联邦学习方法FedAMP 的分类准确率提升了3.52 个百分点. 该结果表明FedAMN 在数据异构情况下相较于其他基线方法,可以更好地实现个性化建模.

Table 3 Classification Accuracy of Different Methods in the Heterogeneity Cases on W&A China Dataset表3 不同方法在W&A China 数据集中不同异构情况下的分类准确率 %

在模型异构实验中,我们对不同站点采用不同的本地模型. 具体来讲,不同站点依然采用2 层GRU作为基础网络. 但在北京和天津的站点中,每层GRU包含128 个隐藏单元,上海和重庆的站点每层GRU包含256 个隐藏单元,以此模拟本地模型异构的情况. 对于联邦设置中的模型聚合操作,我们仅对模型中重叠的参数进行加权平均,得出不同基线方法下的分类准确率. 实验结果如表3 所示,可以看出模型异构对于基线方法均有较大影响,尤其是联邦学习方法,分类准确率相较于单机训练至少下降了30.62个百分点. 这是由于不同模型结构下的全局模型难以聚合,对于结构不同的部分无法实现加权平均. 个性化联邦学习由于有模型微调的功能,准确率略有上升,但基线方法的分类准确率依然均低于单机训练,其中表现最好的FedAMP 的分类准确率相较于单机训练降低了0.56 个百分点. FedAMN 是实验中效果最佳的方法,相较于单机训练,其分类准确率提升了1.94 个百分点,相较于FedAMP 分类准确率提升了2.5 个百分点. 这是由于FedAMP 可以促进具有相似特征和数据分布的客户端之间的协作,并疏远具有广泛不同特征的客户端之间的协作,从而缓解模型异构导致的模型融合效果差的情况.

4.3 训练时间对比实验

AMN 网络的引入增加了客户端的训练开销,同时增加了客户端和服务器之间的通信开销. 为了评估FedAMN 对各项资源的消耗,我们对不同方法的总训练时间进行了测试.

实验如表4 所示,ToAcc@70 表示不同方法的分类准确率第1 次达到70%所花费的训练时间,ToAcc@80 表示不同方法的分类准确率第1 次达到80%所花费的训练时间. 数值越小表示训练越快,资源消耗越小. 可以看出,单机训练方法、联邦学习方法和个性化联邦学习方法所需的训练时间依次递增. 由于FedAMN 需要额外训练AMN 网络,因此在ToAcc@70和ToAcc@80 情况下所花费的训练时间是最长的. 相较于分类准确率最高的基线方法FedAMP,FedAMN的训练时间在ToAcc@70 和ToAcc@80 中分别增加了6.7%和2.8%. 尽管FedAMN 的训练时间相对较长,但它所带来的模型分类准确率提升是更大的. 这使得FedAMN 更加适合对模型效果要求更高、对训练时间要求较为宽松的任务.

Table 4 Classification Accuracy of Different Methods in W&A China Dataset and Required Training Time表4 不同方法在W&A China 数据集中的分类准确率以及所需的训练时间

4.4 消融实验

在FedAMN 中我们引入AMN 来对客户端模型参数进行融合,2 层网络结构分别承担了不同的作用. 为了进一步说明2 层网络设计的必要性,我们通过消融实验来说明训练过程中各层对结果的影响.为了便于对比,我们对AMN 进行分解,增加了3 种变体来进行实验:

1)FTL. 优化目标与FedAMN 相同,但客户端不加入AMN 网络,仅利用全局模型来协助客户实现个性化模型的训练.

2) FedAMN-1 层. 客户端部署的AMN 网络仅包含第1 层结构,即对W进行融合处理.

3) FedAMN-2 层. 客户端部署的AMN 网络仅包含第2 层结构. 由于此时没有第1 层,因此只使用本地模型和全局模型作为网络输入.

4) FedAMN.具备完整2 层AMN 结构的个性化联邦学习训练.

FedAMN 及其3 种变体方法在2 个数据集上的分类准确率和分类精确率结果如图4 所示. 可以看出,FTL 方法的表现最弱,FedAMN-1 层和FedAMN-2 层略优于FTL,但弱于FedAMN. AMN 的加入使得客户端分类准确率在2 个数据集上相较于FTL 分别提升了3.59 个百分点和2.65 个百分点. 同时,我们发现AMN 中第1 层比第2 层为客户端个性化建模带来了更多的好处. 因为第1 层可以调整客户间有益信息的比例,过滤相似度更弱的客户端. 但尽管第2 层的益处略低于第1 层,全局模型信息的加入仍然带来进一步的性能提升. 所以AMN 的2 层结构代表了不同层面的信息融合,两者的结合可以在模型共性和个性间达到动态平衡,为客户间协作关系的刻画提供更全面的帮助.

Fig.4 Ablation experiments of FedAMN and its three variants on two data sets图4 FedAMN 及其3 种变体在2 个数据集上的消融实验

4.5 AMN 的可视化分析

为了对AMN 中注意力机制的有效性进行分析,我们以W&A China 数据集为例,对AMN 的2 层结构分别进行了可视化实验. 如图5 所示,c1~c9为9 个随机选取的位于北京和上海的空气质量监测站点. 其中c5和c9位于上海,其他的位于北京.图5(a)(b)为通信轮次为2 时的AMN 可视化结果,图5(c)(d)为通信轮次为50 时的AMN 可视化结果. 网格的颜色越浅,代表相关性越强,注意力分数越高. 可以看出,在通信轮次为2 时,AMN 的第1 层中每个客户端的注意力权重都趋向平均,第2 层的也相差不大(除了c5因为是上海的站点,具有和其他站点更加不同的特点,watt得到更高的注意力分数). 说明AMN 在训练初期并不能得到很好的训练,实现客户模型相关性的判断. 在通信轮次为50 时,同一个城市的客户端相互得到了更高的注意力分数. 第2 层中的watt所占比例显著增加.c5和c9也给予对方最高的注意力分数.

Fig.5 Results of attention score visualization for two-layer structure of AMN under different communication rounds图5 不同通信轮次下AMN 的2 层结构中的注意力分数可视化结果

Fig.6 Effect of changing the number of sites in different cities for collaborative training on the accuracy of observed client图6 不同城市及站点数量情况下的联邦训练对观测客户准确率的影响

2 组对比实验表明,在训练的早期,本地模型的训练不够充分,AMN 的加入并不能带来很好的效果. 然而,在训练的后期阶段,可以很好地利用AMN,将相似的客户端分组通过分配高注意力分数来加强协作.这也说明了渐进式训练方式的必要性.

4.6 客户端异质性对训练的影响

为了更好地探究客户端异质性对训练的影响,在本节中,我们对FedAMN 和其他基线方法在不同城市及站点数量情况下的分类准确率进行了测评.W&A China 数据集中包含中国4 个直辖市(北京、天津、上海、重庆)中的42 个监测站. 我们选取北京的1 个站点作为观测客户,依次加入北京、天津、上海、重庆的另外3 个站点进行联邦训练,再测量观测客户的分类准确率. 通过调整协同训练过程中参与城市和站点数量,比较不同个性化联邦学习方法对单个客户分类准确率的影响,以检验模型在面对数据异构情况下的训练表现是否良好. FedHealth 需要在服务器上构造源域数据集,在城市感知应用场景下难以实现,因此并未参与测试. 该实验中的对比方法采用联邦学习方法FedAvg 和个性化联邦学习方法FedPer,FTL,FedAMP.

实验结果如6 所示,当只有观测客户独自参与训练时,可以等同于单机训练,分类准确率为73.01 个百分点. 随着不同城市站点的依次加入,观测站点的分类准确率出现改变. 比较明显的是当加入3 个北京站点后,所有方法的分类准确率都出现了不同程度的上升. 而随着天津站点的加入,FedAvg 的分类准确率出现了明显下降. 此时,城市间站点的异构性对全局模型的影响初步显现,作为依赖全局模型实现模型优化的联邦学习方法,FedAvg 无法处理这种异构性,因此观测客户的分类准确率维持在较低水平. 而个性化联邦学习方法都在不同程度上有效处理此种程度的异质性. 当上海和重庆站点加入后,所有方法在观测客户上的分类准确率均有下降. 这是由于北京和天津在地理上更为接近,2 个城市具有相似的空气质量趋势,而上海和重庆作为南方城市,气候条件与北方差异较大,此时将它们放在一起训练会造成分类准确率下降. 尽管如此,FedAMN 仍然保持了最优的效果. 和次优方法FedAMP 相比,上海和重庆的站点加入训练后的分类准确率分别提升了1.67 个百分点和2.34 个百分点. 和联邦学习方法FedAvg 相比,上海和重庆站点加入训练后的分类准确率分别提升了10.94 个百分点和9.14 个百分点. 这表明AMN 的加入可以自动学习客户间相似性,在面对客户间严重的数据异构问题时具有更强的健壮性和有效性.

5 总 结

在个性化联邦学习训练过程中,由于客户端数据的隐私性,现有的方法难以细粒度地提取客户端特征,并定义它们之间的协作关系. 同时,数据异质情况下,全局模型的粗融合也会影响最终建模的性能效果. 更优的方式 ,是为每个客户端提供不同的协作关系. 具体来讲,促进具有相似特征和数据分布的客户端之间的协作,并疏远具有广泛不同特征的客户之间的协作. 因此,对于给定的客户端,如何自动提取与其类似的客户端,是一个具有挑战的问题. 此外,对于服务器-客户端框架,需要设计合理的协作训练方法.

为此,本文将个性化联邦学习中的协作关系刻画问题形式化为一个分布式元学习网络训练任务,并设计了一种注意力增强元学习网络(AMN)来解决客户端本地模型的个性化问题,有效地实现客户端个性与共性的权衡. 并提出了FedAMN 框架,将AMN 分布部署在客户端进行个性化模型训练. 考虑到客户端需要同时训练AMN 和本地网络,我们设计了渐进式的交替训练策略,以端到端的方式对2 个网络进行训练,缓解了元学习网络引入带来的计算量增加问题. 最终实现客户端的个性化注意力增强模型的生成. 为了证明FedAMN 的有效性,我们在2个基准数据集和8 种基准方法上进行了大量实验,相较于现有表现最优的个性化联邦学习方法,我们的方法在2 个数据集中平均分别提升了3.39%和2.45%的模型性能.

在未来的工作中,我们会继续研究FedAMN 中的模型压缩问题,以期进一步减少服务器和客户端之间的通信开销,实现个性化联邦学习技术在智能感知场景中的应用.

作者贡献声明:高雨佳提出了方法思路与实验方案,完成实验并撰写论文;王鹏飞提供理论指导、研究思路分析及论文修改意见;刘亮和马华东提供理论分析和技术指导,提出论文修改意见并最终审核.

猜你喜欢

联邦客户端准确率
乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析
不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨
2015—2017 年宁夏各天气预报参考产品质量检验分析
一“炮”而红 音联邦SVSound 2000 Pro品鉴会完满举行
303A深圳市音联邦电气有限公司
高速公路车牌识别标识站准确率验证法
县级台在突发事件报道中如何应用手机客户端
孵化垂直频道:新闻客户端新策略
基于Vanconnect的智能家居瘦客户端的设计与实现
20年后捷克与斯洛伐克各界对联邦解体的反思