首页> 中国专利> 一种基于两阶段聚类提升联邦学习模型性能的训练方法和存储设备

一种基于两阶段聚类提升联邦学习模型性能的训练方法和存储设备

摘要

本发明涉及模型训练技术领域,特别涉及一种基于两阶段聚类提升联邦学习模型性能的训练方法和存储设备。所述一种基于两阶段聚类提升联邦学习模型性能的训练方法,包括步骤:获取每个客户在本地通过联邦学习训练好后的模型;通过K‑Center聚类方法将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器;通过该步骤是对具有相似数据集的客户端的集群进行模型训练,减少了客户端数据分布的差异,加快收敛速度。将每个集群中的客户端随机分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。通过进一步采用细粒度聚类方法来拉平了原始数据分布。通过以上步骤大大提高了联邦学习在Non‑IID环境下的模型性能。

著录项

  • 公开/公告号CN113313266A

    专利类型发明专利

  • 公开/公告日2021-08-27

    原文格式PDF

  • 申请/专利权人 厦门大学;

    申请/专利号CN202110661289.1

  • 发明设计人 翁建庆;苏松志;

    申请日2021-06-15

  • 分类号G06N20/00(20190101);G06K9/62(20060101);

  • 代理机构35219 福州市景弘专利代理事务所(普通合伙);

  • 代理人魏小霞;徐宝珺

  • 地址 361005 福建省厦门市思明区思明南路422号

  • 入库时间 2023-06-19 12:22:51

说明书

技术领域

本发明涉及模型训练技术领域,特别涉及一种基于两阶段聚类提升联邦学习模型性能的训练方法和存储设备。

背景技术

我们知道海量的训练样本是训练一个高精度模型的重要和必要前提。而随着大数据和5G时代的来临,企业和公司每天都可以从用户的手机、可穿戴设备等边缘终端产生大量的数据。这使得企业之间可以选择合作训练模型,主要通过把他们各自产生的数据传送到第三方中心服务器上,在中心服务器上利用全部的传送数据来训练一个高精度模型。但这样做不仅会导致高传输延迟、第三方服务器容量不足等问题,在一些金融和医疗等领域,通常还会因为用户数据的隐私敏感性而使得企业在源头就无法传送数据给第三方。

目前业界流行的一种解决方案是使用联邦学习。联邦学习是一种新颖的分布式合作学习方法,可以用来协同训练高性能模型。与传统的集中式机器学习不同,联邦学习把模型从第三方中心服务器推送到各个客户(这里的客户指向参与合作训练模型的企业或机构)中,网络训练的计算主要在本地客户中进行,因此用户的数据隐私可以得到很好的保护。

传统的联邦学习范式包含以下过程:(1)中心服务器首先将初始化的模型推送给各个客户,(2)每个客户从中心服务器接收到模型后,使用自己本地的数据集来继续训练。(3)每个客户将他们各自训练好的模型上传到中心服务器上。(4)中心服务器将所有在线客户传送上来的模型做加权平均,聚合成单个模型。上面四个步骤持续重复执行直到模型的训练达到了收敛效果。目前这种联邦学习范式代表的算法有FedAvg和FedProx。

其存在以下缺点:

缺点1:

训练数据的IID(独立同分布)采样是确保随机梯度是完整梯度的无偏估计的重要前提,只有当客户之间的数据是IID分布时,联邦学习的训练效果才会显著。但这种IID假设在现实中难以保证。由于不同的业务场景和用户行为,不同客户之间的数据集通常是异质的,这也导致联邦学习在实际应用中会发生模型性能下降。

缺点2:

联邦学习试图学习一个全局共享的模型来适应所有客户端的数据分布,但是当跨客户端的数据是异质的时候,模型会偏离其最佳方向,导致性能下降。

发明内容

为此,需要提供一种基于两阶段聚类提升联邦学习模型性能的训练方法,用以解决联邦学习在Non-IID环境下模型性能低的问题。具体技术方案如下:

一种基于两阶段聚类提升联邦学习模型性能的训练方法,包括步骤:

获取每个客户在本地通过联邦学习训练好后的模型;

通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器;

通过第二阶段聚类将每个集群中的客户端分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。

进一步的,所述“通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器”,具体还包括步骤:

通过K-Center聚类方法将所有客户端分组为两个以上集群。

进一步的,所述“通过K-Center聚类方法将所有客户端分组为两个以上集群”,具体还包括步骤:

根据客户端本地模型参数的相似性构建客户集群,不同的集群间互不干扰地进行联邦学习。

进一步的,所述“采用预设算法对每个细粒度集群进行训练”,具体还包括步骤:

通过“伪”小批量SGD对每个细粒度集群进行训练。

为解决上述技术问题,还提供了一种存储设备,具体技术方案如下:

一种存储设备,其中存储有指令集,所述指令集用于执行:获取每个客户在本地通过联邦学习训练好后的模型;

通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器;

通过第二阶段聚类将每个集群中的客户端分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。

进一步的,所述指令集还用于执行:所述“通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器”,具体还包括步骤:

通过K-Center聚类方法将所有客户端分组为两个以上集群。

进一步的,所述指令集还用于执行:所述“通过K-Center聚类方法将所有客户端分组为两个以上集群”,具体还包括步骤:

根据客户端本地模型参数的相似性构建客户集群,不同的集群间互不干扰地进行联邦学习。

进一步的,所述指令集还用于执行:所述“采用预设算法对每个细粒度集群进行训练”,具体还包括步骤:

通过“伪”小批量SGD对每个细粒度集群进行训练。

本发明的有益效果是:一种基于两阶段聚类提升联邦学习模型性能的训练方法,包括步骤:获取每个客户在本地通过联邦学习训练好后的模型;通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器;通过第二阶段聚类将每个集群中的客户端分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。通过以上方法,采用两阶段聚类机制,可以稳定消除Non-IID数据的负面影响并提高学习的收敛速度。

进一步的,通过第一阶段聚类来根据客户端本地模型参数的相似性来构建客户集群,不同的集群之间互不干扰地进行联邦学习。从而不再是学习一个全局共享的模型来适应所有客户端的数据分布,因此可以有效降低数据分布偏倚,从而提高全局模型的性能。

进一步的,通过第二阶段聚类来进一步地降低客户端之间的数据分布差异,在第一阶段聚类的基础上,继续把每个客户划分为多个细粒度集群,把细粒度集群当作一个训练模型的单位,利用提出的“伪”小批量SGD训练方法,可以达到接近通用的小批量SGD训练效果并且保护了用户的数据隐私。

附图说明

图1为具体实施方式所述一种基于两阶段聚类提升联邦学习模型性能的训练方法的流程图;

图2为具体实施方式所述传统的联邦学习范式示意图;

图3为具体实施方式所述第一阶段聚类示意图;

图4为具体实施方式所述第二阶段聚类示意图;

图5为具体实施方式所述“伪”小批量SGD训练方法示意图;

图6为具体实施方式所述一种存储设备的模块示意图。

附图标记说明:

600、存储设备。

具体实施方式

为详细说明技术方案的技术内容、构造特征、所实现目的及效果,以下结合具体实施例并配合附图详予说明。

请参阅图1至图5,在本实施方式中,一种基于两阶段聚类提升联邦学习模型性能的训练方法可应用在一种存储设备上,所述存储设备包括但不限于:个人计算机、服务器、通用计算机、专用计算机、网络设备、嵌入式设备、可编程设备、智能移动终端等。

首先对本申请的核心技术思想进行说明:因发现客户之间数据分布的差异与其模型偏离度(model divergence)之间存在联系。因此,本申请首先使用K-Center聚类方法将所有客户端分组为多个集群(第一阶段聚类),每个集群对应一个不同的中央服务器。第一阶段聚类策略可以训练多个不相交的模型,这些模型针对具有相似数据集的客户端的集群,从而减少客户端数据分布的差异,加快收敛速度。

因在理论上证明了,在训练中,如果客户端的训练数据更加均匀分布,则可以减少模型偏离度的增长,从而提高在Non-IID环境下的训练性能。因此,本申请进一步采用细粒度聚类方法来拉平原始数据分布。首先将每个集群中的客户端随机分组为多个细粒度集群(第二阶段聚类),每个集群执行“伪”小批量SGD来训练本地模型,这种“伪”小批量SGD方法对每个细粒度集群进行通用的小批量SGD训练,并将数据保存在了本地。本申请的第二阶段聚类策略可以使客户端之间的数据分布趋于平坦,并减少由Non-IID环境引起的模型偏离度(model divergence)的增长。

以下具体展开说明:

步骤S101:获取每个客户在本地通过联邦学习训练好后的模型。具体可如下:首先如图2所示,执行传统的联邦学习过程,直至训练到t-1轮为止。在第t轮中,每个客户在本地训练好后将模型传送至中心服务器上。

如图3所示,步骤S102:通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器。在本实施方式中,具体采用通过K-Center聚类方法将所有客户端分组为两个以上集群。K-Centers聚类算法是针对每个客户的模型权重来进行聚类的,是根据客户端本地模型参数的相似性构建客户集群,相似性度高的客户端将其构建到一个集群中。每个集群都对应一个新的且独立中心服务器来负责搜集用户上传的模型并执行模型聚合操作。从此以后,各个聚类之间的训练独立且互不干扰。

如图4所示,步骤S103:通过第二阶段聚类将每个集群中的客户端分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。具体可如下:为了使得客户端之间的数据分布更均匀,在每个大集群中进一步采用聚类把大集群里面的每个客户再归为若干个细粒度集群。这里的聚类算法既可以是随机划分,也可以是按照客户所在的地域来划分,只要把若干个客户划分为一个细粒度集群就可以了。

在第二步聚类后会得到若干个细粒度集群,这时把这些细粒度集群当作一个新的“客户”来看待,可以证明这些“新”客户之间的数据集分布比划分为细粒度集群前的数据集分布要更加均匀。因此以细粒度集群为一个训练单位来训练一个本地模型,接着把每个训练后的本地模型上传到对应的中心服务器即可。

对于每一个细粒度集群,为了同时利用到集群里面的每一个客户的训练集,同时还要不侵犯数据隐私的情况下训练出一个模型,在本实施方式中,预设算法采用了“伪”小批量SGD的训练方法。这是一种序列化训练模型的方式,首先每轮从细粒度集群里随机挑选出一个客户,客户半训练完后将它的模型传送个下一个客户,下一个客户用上一个客户传来的模型作为初始化模型在它自己的本地数据集里训练,接着再把训练模型传给下一个客户,持续这个过程直到这个细粒度集群内的所有客户都在自己本地上执行过模型训练,并在细粒度集群内序列化训练了若干轮为止。

本实施方式中的“伪”小批量SGD训练其实是一种序列行训练方式,但是它能和通用的小批量SGD训练方式达到相同的效果。如图5所示,其上方是把所有的客户的数据聚集在一起然后训练一个模型,这是通用的小批量SGD算法的训练过程,而图5下方则是把训练的模型按顺序地往下传送给其它客户接着训练,这就是本申请改进后的“伪”小批量SGD算法,且本申请提出的“伪”小批量SGD训练还能够保证数据不会离开本地,从而保护了用户数据隐私。

一种基于两阶段聚类提升联邦学习模型性能的训练方法,包括步骤:获取每个客户在本地通过联邦学习训练好后的模型;通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器;通过第二阶段聚类将每个集群中的客户端分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。通过以上方法,采用两阶段聚类机制,可以稳定消除Non-IID数据的负面影响并提高学习的收敛速度。

进一步的,通过第一阶段聚类来根据客户端本地模型参数的相似性来构建客户集群,不同的集群之间互不干扰地进行联邦学习。从而不再是学习一个全局共享的模型来适应所有客户端的数据分布,因此可以有效降低数据分布偏倚,从而提高全局模型的性能。

进一步的,通过第二阶段聚类来进一步地降低客户端之间的数据分布差异,在第一阶段聚类的基础上,继续把每个客户划分为多个细粒度集群,把细粒度集群当作一个训练模型的单位,利用提出的“伪”小批量SGD训练方法,可以达到接近通用的小批量SGD训练效果并且保护了用户的数据隐私。

请参阅图2至图6,在本实施方式中,一种存储设备600的具体实施方式如下:

一种存储设备600,其中存储有指令集,所述指令集用于执行:获取每个客户在本地通过联邦学习训练好后的模型。具体可如下:首先如图2所示,执行传统的联邦学习过程,直至训练到t-1轮为止。在第t轮中,每个客户在本地训练好后将模型传送至中心服务器上。

通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器。在本实施方式中,具体采用通过K-Center聚类方法将所有客户端分组为两个以上集群。K-Centers聚类算法是针对每个客户的模型权重来进行聚类的,是根据客户端本地模型参数的相似性构建客户集群,相似性度高的客户端将其构建到一个集群中。每个集群都对应一个新的且独立中心服务器来负责搜集用户上传的模型并执行模型聚合操作。从此以后,各个聚类之间的训练独立且互不干扰。

通过第二阶段聚类将每个集群中的客户端分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。。具体可如下:为了使得客户端之间的数据分布更均匀,在每个大集群中进一步采用聚类把大集群里面的每个客户再归为若干个细粒度集群。这里的聚类算法既可以是随机划分,也可以是按照客户所在的地域来划分,只要把若干个客户划分为一个细粒度集群就可以了。

在第二步聚类后会得到若干个细粒度集群,这时把这些细粒度集群当作一个新的“客户”来看待,可以证明这些“新”客户之间的数据集分布比划分为细粒度集群前的数据集分布要更加均匀。因此以细粒度集群为一个训练单位来训练一个本地模型,接着把每个训练后的本地模型上传到对应的中心服务器即可。

对于每一个细粒度集群,为了同时利用到集群里面的每一个客户的训练集,同时还要不侵犯数据隐私的情况下训练出一个模型,在本实施方式中,预设算法采用了“伪”小批量SGD的训练方法。这是一种序列化训练模型的方式,首先每轮从细粒度集群里随机挑选出一个客户,客户半训练完后将它的模型传送个下一个客户,下一个客户用上一个客户传来的模型作为初始化模型在它自己的本地数据集里训练,接着再把训练模型传给下一个客户,持续这个过程直到这个细粒度集群内的所有客户都在自己本地上执行过模型训练,并在细粒度集群内序列化训练了若干轮为止。

如图5所示,本实施方式中的“伪”小批量SGD训练其实是一种序列行训练方式,但是它能和通用的小批量SGD训练方式达到相同的效果。且本申请提出的“伪”小批量SGD训练还能够保证数据不会离开本地,从而保护了用户数据隐私。

一种存储设备600,其中存储有指令集,所述指令集用于执行:获取每个客户在本地通过联邦学习训练好后的模型;通过第一阶段聚类将所有客户端分组为两个以上集群,每个集群对应一个不同的中央服务器;通过第二阶段聚类将每个集群中的客户端分组为两个以上细粒度集群,采用预设算法对每个细粒度集群进行训练。通过以上指令集执行,采用两阶段聚类机制,可以稳定消除Non-IID数据的负面影响并提高学习的收敛速度。

进一步的,通过第一阶段聚类来根据客户端本地模型参数的相似性来构建客户集群,不同的集群之间互不干扰地进行联邦学习。从而不再是学习一个全局共享的模型来适应所有客户端的数据分布,因此可以有效降低数据分布偏倚,从而提高全局模型的性能。

进一步的,通过第二阶段聚类来进一步地降低客户端之间的数据分布差异,在第一阶段聚类的基础上,继续把每个客户划分为多个细粒度集群,把细粒度集群当作一个训练模型的单位,利用提出的“伪”小批量SGD训练方法,可以达到接近通用的小批量SGD训练效果并且保护了用户的数据隐私。

以下附上本申请的一些证明过程:

首先证明把所有客户划分为若干个细粒度集群后(此时把该集群中所有的数据集都加起来)的数据集分布会比划分前更加均匀。

假设有K个用户,将它们每S个分为一组,一共就有

接下来证明

上式说明了细粒度集群j上的数据分布与数据集总分布之间的EMD小于将客户端划分到第j组细粒度集群之前的EMD,这说明这种简单的聚类策略可以帮助我们将数据分布拉平。

Appendix B.

proof of

We convert the proof of the above inequality into the following:

Substituting Eq.(5)into inequality(8),we can get:

Because inequality

holds,so(9)holds.

Therefore.we have completed the proof of

接下来证明如果客户的训练数据更加均匀分布,则可以减少模型偏离度的增长,从而提高在Non-IID环境下的训练性能。

首先根据现有论文《Federated learning with non-iid data》,可以得出下列不等式:

上式的

当我们将K个客户按S个客户为一组分为若干个细粒度集群后,其模型偏离度变为如下:

现在只需要证明

Appendix A.

Proof of

Proof of inequality(1)is equivalent to Proof of inequality(2)

Let n′

Because

According to Binomial Theorem,we can get

and

where

Since we only consider the C-class classification problem,we canrewrite the discrete probability distribution as

where

Substitute Eq.(3),(4),(5)into inequality(2),we now need to prove

which is equivalent to prove

When t is 0 or 1,inequality(6)apparently establish.When t>1,wereplace the discrete probability distribution as shown in(5).Thereforeinequality(6)become:

Let

Next we provide the proof of inequality(7),First we have thefollowing lemma,which is the famous

convert the inequality (7)to:

we just need to apply

and we can directly prove the inequality(7).

So far.we have completed the proof of

需要说明的是,尽管在本文中已经对上述各实施例进行了描述,但并非因此限制本发明的专利保护范围。因此,基于本发明的创新理念,对本文所述实施例进行的变更和修改,或利用本发明说明书及附图内容所作的等效结构或等效流程变换,直接或间接地将以上技术方案运用在其他相关的技术领域,均包括在本发明的专利保护范围之内。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号