首页> 中国专利> 一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统

一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统

摘要

本发明涉及一种联邦多任务学习中基于特征提取‑子任务分类器的数据标签分类方法及系统,适用于中央节点式联邦学习系统。为了提升整体模型的有效性和精度并解决标签缺失数据的问题,本发明通过两步分离式的联邦多任务学习训练方式,实现了一个“特征提取‑子任务分类器”的统一网络架构设计。该设计方法能够解决联邦多任务学习中多标签数据的部分标签缺失问题并拥有较高的模型性能以及测试精度,最终训练出一个高性能多标签分类器网络,同时保护了用户节点的数据隐私。

著录项

  • 公开/公告号CN114882245A

    专利类型发明专利

  • 公开/公告日2022-08-09

    原文格式PDF

  • 申请/专利权人 山东大学;

    申请/专利号CN202210438889.6

  • 发明设计人 郭帅帅;王謇达;史高鑫;张海霞;

    申请日2022-04-22

  • 分类号G06V10/46(2022.01);G06V10/774(2022.01);G06V10/764(2022.01);G06V10/82(2022.01);G06V10/94(2022.01);G06N3/04(2006.01);G06N3/08(2006.01);G06N20/00(2019.01);G06K9/62(2022.01);G06F21/62(2013.01);

  • 代理机构济南金迪知识产权代理有限公司 37219;

  • 代理人杨树云

  • 地址 250199 山东省济南市历城区山大南路27号

  • 入库时间 2023-06-19 16:19:08

法律信息

  • 法律状态公告日

    法律状态信息

    法律状态

  • 2023-08-25

    授权

    发明专利权授予

  • 2022-08-26

    实质审查的生效 IPC(主分类):G06V10/46 专利申请号:2022104388896 申请日:20220422

    实质审查的生效

  • 2022-08-09

    公开

    发明专利申请公布

说明书

技术领域

本发明涉及一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统,属于人工智能技术领域。

背景技术

联邦学习属于分布式机器学习,是一种新兴的机器学习框架。随着大数据时代的到来,用户的数据安全和隐私保护越来越重要,诸多国家也出台了隐私保护相关法律法规。而对于训练大规模机器学习模型,传统的分布式机器学习往往不涉及数据隐私问题,中央服务器对计算节点以及其中的数据具有较高的控制权。2016年,谷歌公司提出了联邦学习,旨在每个用户数据不出本地仍可参与到模型的训练中去,以实现保护各参与者数据安全的目的。联邦学习中各个用户节点通过本地的私有数据训练模型,经中央服务器协调,聚合各个用户节点的模型参数,并更新全局的模型。这期间不涉及数据的传输,在很大的程度上保护了数据安全。详见文献[1]:Mcmahan H B,Moore E,D Ramage,et al.Communication-Efficient Learning of Deep Networks from Decentralized Data[J].2016.。

传统的机器学习训练中,数据往往是单标签的,即每个实例都与仅一个标签相关联,以表示其概念类的归属。然而,在许多实际应用中,一个对象通常会附带有多个标签,即一个实例对应于一组标签。例如,在文本分类任务中,一个文档可能属于多个主题,如“小说”、“社会”;在图像分类任务中,一幅图像可能属于多个语义,如“猫”、“白色”。使用多标签数据的多标签学习在从文档分类到基因功能预测和自动图像注释等诸多应用中起着至关重要的作用。多标签分类中,一种常见方法是问题转换,即将一个多标签问题转化为一个或多个单标签分类器来进行分类,再将其转换为多标签表示。详见文献[2]:Read J,Pfahringer B,Holmes G,et al.Classifier chains for multi-label classification[J].Machine Learning,2011,85(3):333-359.。

在多标签学习中,一个常见的假设是,所有的类标签及其值在训练过程之前都被观测到。然而,在一些实际应用中,由于标签的标注成本很高、一些标签在标注过程中的刻意省略以及部分标签存在未知性等因素,故一些观测到的标签是缺失的,甚至有部分标签并没有被观测到。这给多标签分类任务造成了巨大的困难。因此,如何在多标签分类任务中,解决标签缺失的问题并保证较好的分类精度,得到了广泛关注。用现有技术来解决标签缺失数据的多标签学习问题,详见文献[3]:Sun Y Y,Zhang Y,Zhou Z H.Multi-labellearning with weak label[C]//Twenty-fourth AAAI conference on artificialintelligence.2010.,一个基本的先决条件是每个标签至少有一个正向的数据示例,即每个标签至少在数据中出现一次。但此类方法无法解决某一标签完全缺失的问题,有一定的局限性,实用性不足。

发明内容

针对现有技术的不足,本发明提供了一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法。

本发明通过两步分离式的联邦多任务学习训练方式,实现了一个“特征提取-子任务分类器”的统一网络架构设计。所有用户节点参与构建适用于所有用户数据的特征提取网络,该特征提取网络在所给用户数据上具有普适性。原始训练图像经过特征提取网络,输出提取显著特征后的图像数据,使得在下一步分类器网络的训练中降低训练损失,提高测试精度。通过子任务分类器网络,使得某些用户节点的数据缺失某些标签而不能完成模型训练的问题得以解决。其中,每个子任务分类器网络不再单独训练特征提取层,而是通过训练一个针对于所有用户的特征提取网络,完成输入图像的特征提取,降低子任务分类器网络模型复杂度。该设计方法能够解决联邦多任务学习中多标签数据的部分标签缺失问题并拥有较高的模型性能以及测试精度,最终训练出一个高性能多标签分类器网络,同时保护了用户节点的数据隐私。

本发明还提供了一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类系统。

术语解释:

1、MBGD:小批量梯度下降法;

2、MSELoss:均方误差损失函数;

3、CrossEntropyLoss:交叉熵损失函数;

4、One-Hot:独热编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效。

本发明的技术方案为:

一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,适用于中央节点式联邦学习系统,所述中央节点式联邦学习系统包括M个用户节点和1个中央服务器,每个用户节点均与中心服务器相连接;设定所有用户的训练数据均为多标签数据,且来自于同一个特征空间,标签的总数为L;对任意的用户,其每一个本地数据点都拥有相同种类的标签;第m个用户拥有的本地数据数目用K

构建并训练全局模型,全局模型包括特征提取网络以及多个分类器网络;

将待分类的图像输入到训练好的全局模型,图像数据经过特征提取网络,提取特征;提取特征后的图像数据再经过所有分类器网络,每个分类器网络分别输出该待分类的图像对于每一种标签中,属于各个类别的概率输出值;每一种标签选择概率输出值最大的类别作为此标签的分类结果,最终输出每一种标签的分类结果;

其中,全局模型的训练过程为:

第一步,训练特征提取网络:

在第t个特征提取网络训练周期中,用户节点m收到由中央服务器广播的最新特征提取网络的模型参数w

第二步,训练多个分类器网络:

根据每个用户节点对应的数据标签对用户节点进行分组,设定分为L组,第i组的用户节点个数记为M

对于第i组的所有用户节点,目标为训练一个分类器网络i,其中,i表示在所有分类器网络中此分类器网络的索引号;分类器网络i在第t个训练周期中,第m

在该用户组所有用户节点完成一轮训练后,各用户节点将各自更新后的本地分类器网络i的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的分类器网络i,其模型参数为

对于全部L组用户节点均进行上述分类器网络的训练过程,直至中央服务器端的所有分类器网络收敛。

全局模型的训练过程包括:

在中央节点式联邦学习系统中,基于所有用户拥有的数据,构建一个统一的、适用于所有用户的特征提取网络

所有用户节点使用此特征提取网络,执行对本地数据的特征提取,得到提取显著特征后的图像数据;

在中央节点式联邦学习系统中,设定每一个用户节点的训练数据只拥有部分标签,即有一些标签是缺失的,且同一用户的数据所缺失标签是一致的;首先,根据用户节点拥有的标签进行分组,将拥有同一标签的用户节点集合称为一个用户组,形成多个用户组;之后;对于每一用户组的用户节点,均通过联邦学习的形式,训练出一个适用于此组标签的分类器网络;对于第i个标签的用户组所训练的分类器网络表示为

根据本发明优选的,定义特征提取网络的学习目标是最小化一个经验损失函数,如式(I)、(II)所示:

式(I)中,F(w)表示全局的平均训练损失,w表示d维的模型参数向量,F

根据本发明优选的,在用户节点m收到由中央服务器广播的最新特征提取网络的模型参数w

式(III)中,

在第t个特征提取网络训练周期中,所有的用户节点选择在本地通过MBGD法进行多次的本地特征提取网络训练损失的梯度更新;然后再将最新本地特征提取网络训练损失的梯度{g

根据本发明优选的,特征提取网络

根据本发明优选的,本地特征提取网络的损失函数选用MSELoss损失函数f(x

f(x

其中,x

根据本发明优选的,第i个分类器网络的局平均训练损失F

式(VI)和(VII)中,上标i表示该变量对应了第i个分类器网络;F

根据本发明优选的,每个分类器网络包括线性层和激活层,输入特征提取后的图像数据后,分类器网络分别输出图片对于特定一种标签中,属于各个类别的概率输出值,每一种标签选择概率最大的类别作为此标签的分类结果。

根据本发明优选的,每个分类器网络的本地损失函数均选用CrossEntropyLoss损失函数,其计算方法如(VIII)式所示:

如(VIII)中,输入x

一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现联邦多任务学习中解决子任务数据标签缺失的方法的步骤。

一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现联邦多任务学习中解决子任务数据标签缺失的方法的步骤。

一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类系统,包括:

特征提取模块,被配置为,对待分类的图像进行特征提取,提取出图像数据的主要特征;使得图片数据的RGB特征分量增加,总特征数目明显提升;

标签分类模块,被配置为,从分类器网络中输出对应某一标签的分类结果。

本发明的有益效果为:

本发明针对中央节点式联邦学习系统应用场景,提出了一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法。通过两步分离式的联邦多任务学习架构,实现了联合“特征提取网络”、“分类器网络”的统一设计。原始训练图像经过统一的特征提取网络,输出提取显著特征后的图像数据,使得在下一步分类器网络的训练中降低训练损失,提高测试精度以及模型的有效性,同时,每个子任务分类器网络不再单独训练特征提取层,而是通过训练一个针对于所有用户的特征提取网络,降低了模型复杂度。通过分类器网络,使得部分用户节点数据缺失某些标签而不能完成模型训练的问题得以解决。以联邦学习的方式进行分组多任务训练,能够在保护用户数据隐私的前提下,训练出一个高性能的多标签分类器网络。

附图说明

图1是本发明联邦多任务学习中解决子任务数据标签缺失的方法的流程框图;

图2(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的训练损失示意图;

图2(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的训练损失示意图;

图3(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的测试精度示意图;

图3(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的测试精度示意图;

图4是卷积自编码器网络的结构示意图;

图5是分类器网络的网络的结构示意图。

具体实施方式

下面结合说明书附图和实施例对本发明作进一步限定,但不限于此。

实施例1

一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,适用于中央节点式联邦学习系统,所述中央节点式联邦学习系统包括M个用户节点和1个中央服务器,每个用户节点均与中心服务器相连接;设定所有用户的训练数据均为多标签数据,且来自于同一个特征空间,标签的总数为L;对任意的用户,其每一个本地数据点都拥有相同种类的标签;第m个用户拥有的本地数据数目用K

构建并训练全局模型,全局模型包括特征提取网络以及多个分类器网络;特征提取-子任务分类器即全局模型。

将待分类的图像输入到训练好的全局模型,图像数据经过特征提取网络,提取特征;提取特征后的图像数据再经过所有分类器网络,每个分类器网络分别输出该待分类的图像对于每一种标签中,属于各个类别的概率输出值;每一种标签选择概率输出值最大的类别作为此标签的分类结果,最终输出每一种标签的分类结果;

其中,全局模型的训练过程为:

第一步,训练特征提取网络:

在第t个特征提取网络训练周期中,用户节点m收到由中央服务器广播的最新特征提取网络的模型参数w

第二步,训练多个分类器网络:

根据每个用户节点对应的数据标签对用户节点进行分组,设定分为L组,第i组的用户节点个数记为M

对于第i组的所有用户节点,目标为训练一个分类器网络i,其中,i表示在所有分类器网络中此分类器网络的索引号;分类器网络i在第t个训练周期中,第m

在该用户组所有用户节点完成一轮训练后,各用户节点将各自更新后的本地分类器网络i的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的分类器网络i,其模型参数为

对于全部L组用户节点均进行上述分类器网络的训练过程,直至中央服务器端的所有分类器网络收敛。

实施例2

根据实施例1所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其区别在于:

全局模型的训练过程包括:

为了提升整体模型的精度和有效性并解决标签缺失数据的问题,本发明采用“特征提取、子任务分类器”的两步分离式的网络架构;

为了降低后续子任务分类器网络的训练损失,提高测试精度,在中央节点式联邦学习系统中,基于所有用户拥有的数据,构建一个统一的、适用于所有用户的特征提取网络

所有用户节点使用此特征提取网络,执行对本地数据的特征提取,得到提取显著特征后的图像数据;用于后续的分类器模型训练;

在中央节点式联邦学习系统中,设定每一个用户节点的训练数据只拥有部分标签,即有一些标签是缺失的,且同一用户的数据所缺失标签是一致的;首先,运用多任务学习的思想,根据用户节点拥有的标签进行分组,将拥有同一标签的用户节点集合称为一个用户组,形成多个用户组;之后;对于每一用户组的用户节点,均通过联邦学习的形式,训练出一个适用于此组标签的分类器网络;对于第i个标签的用户组所训练的分类器网络表示为

将特征提取网络

定义特征提取网络的学习目标是最小化一个经验损失函数,如式(I)、(II)所示:

式(I)中,F(w)表示全局的平均训练损失,w表示d维的模型参数向量,F

在用户节点m收到由中央服务器广播的最新特征提取网络的模型参数w

式(III)中,

在第t个特征提取网络训练周期中,所有的用户节点选择在本地通过MBGD法进行多次的本地特征提取网络训练损失的梯度更新;然后再将最新本地特征提取网络训练损失的梯度{g

特征提取网络

对于训练完毕的特征提取网络,只采用编码器部分。将原始图像数据输入到编码器,得到提取显著特征后的图像数据,用于后续的分类器网络训练。

本地特征提取网络的损失函数选用MSELoss损失函数f(x

f(x

其中,x

第i个分类器网络的局平均训练损失F

式(VI)和(VII)中,上标i表示该变量对应了第i个分类器网络;F

因为每个用户节点的训练数据已经经过特征提取,每个分类器网络包括线性层和激活层,网络结构如图5所示。对于某一个分类器网络,输入特征提取后的图像数据后,分类器网络分别输出图片对于特定一种标签中,属于各个类别的概率输出值,每一种标签选择概率最大的类别作为此标签的分类结果。

每个分类器网络的本地损失函数均选用CrossEntropyLoss损失函数,其计算方法如(VIII)式所示:

如(VIII)中,输入x

选取CelebA数据集中的40000个数据点,下放到所有用户节点。这些数据点只附带有40个原始标签的中两个标签,每一个数据点只有一个标签,即为标签缺失。

图2(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的训练损失示意图;图2(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的训练损失示意图;横坐标为训练轮次,纵坐标为训练数据的损失。

图3(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的测试精度示意图;图3(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的测试精度示意图;其中,横坐标为训练轮次,纵坐标为测试数据的精度。

由图2(a)、图2(b)、图3(a)、图3(b)可知,由于在特征提取网络中,经卷积编码器提取出图像的显著特征,能提高子任务分类器的性能、稳定性以及分类精度。通过此“特征提取-子任务分类器”的网络架构,能够有效解决多标签数据的标签缺失的问题。同时以联邦学习的方式,在保证了参与训练用户的数据隐私的前提下,仍然能够保持较高模型性能以及测试精度,可见此设计的有效性。

将本发明应用到医学图像的标签识别上,训练数据为所有用户某医学图像,多类标签为不同病症或科室的诊断结果(存在标签缺失状况)。在训练医学图像智能诊断模型中,实现联邦多任务学习中解决子任务数据标签缺失的方法的步骤。

实施例3

一种计算机设备,包括存储器和处理器,存储器存储有计算机程序,处理器执行计算机程序时实现实施例1或2一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法的步骤。

实施例4

一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现实施例1或2一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法的步骤。

实施例5

一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类系统,包括:

特征提取模块,被配置为,对待分类的图像进行特征提取,提取出图像数据的主要特征;使得图片数据的RGB特征分量增加,总特征数目明显提升;

标签分类模块,被配置为,从分类器网络中输出对应某一标签的分类结果。

去获取专利,查看全文>

相似文献

  • 专利
  • 中文文献
  • 外文文献
获取专利

客服邮箱:kefu@zhangqiaokeyan.com

京公网安备:11010802029741号 ICP备案号:京ICP备15016152号-6 六维联合信息科技 (北京) 有限公司©版权所有
  • 客服微信

  • 服务号