loading...
SenSys21-FedMask: Joint Computation and Communication-Efficient Personalized Federated Learning via Heterogeneous Masking
Published in:2022-12-24 | category: About Thesis

SenSys21-FedMask: Joint Computation and Communication-Efficient Personalized Federated Learning via Heterogeneous Masking

image-20221224150046850

本篇工作和MobiCom21-Hermes: An Efficient Federated Learning Framework for Heterogeneous Mobile的思想本质上一致。

FedMask improves the inference accuracy by 28.47% and reduces the communication cost and the computation cost by 34.48× and 2.44×. FedMask also achieves 1.56× inference speedup and reduces the energy consumption by 1.78×.

Thesis inference accuracy communication cost computation cost inference speedup energy consumption
SenSys21 28.47% 34.48x 2.44x 1.56x 1.78x
MobiCom21 32.17% 3.48x 1.83x 1.8x

每个设备学习一个稀疏的binary mask(每个网络参数1个bit),保持本地模型的参数不变,在服务器和设备之间只传输binary mask。

和经典的FL学习共享模型不同,每个device通过将学到的binary mask运用到本地模型固定的参数上去得到一个个性化的和结构稀疏的模型。

  • 只在device和server之间传输mask的话,本地得到的其它server的信息实质上只能是本地模型的参数哪些地方保留,哪些地方剪掉这样的信息,本地想取得不错的效果,要求本地能够训练出不错的模型,如果本地的数据如果太少,本地就无法训练出不错的模型,因此这个方法对本地数据很少的情况可能并不适用。
  • 为什么不把这个方法和MobiCom21的方法结合起来去做一个优化?
  • 剪枝的方法需要对每一个特定的模型做特定的剪枝策略设计和算法设计,并不是一个通用的方法。
  • 该工作和MobiCom21那一篇的主要精力其实都集中于如何对手机上的模型进行剪枝,关于中央服务器聚合的部分,其实花的精力并不多。

Background(当前存在什么问题)

和MobiCom21的那篇基本一致。

image-20221224152144195

Background增加了计算成本论述的一小节。

Challenges(当前的挑战)

  • 第一个挑战是如何联合优化通信和计算效率,Fedmask要去优化binary masks而不是去优化local model parameters只把optimized binary masks送到中央服务器。因此,需要设计一个优化binary mask的训练方法,当前的SGD方法不可用因为mask中的元素都binary value,0值会阻止梯度下降的反向传播,因此需要设计一个binary mask的训练方法。此外,优化binary mask的时候需要设计限制,否则就会变成非结构化稀疏方法,对硬件来说并不友好。
  • 第二个挑战是如何保留每个device的个性,当前的剪枝方法是个model weights这种浮点值设计的,这种方法不能被直接用到binary values上,所以也不能直接运用到FedMask的binary masks上,因此需要设计一个可以生成异构binary mask的方法。
  • 第三个挑战是如何在保留device个性化的前提下聚合异构的二进制mask。聚合有两个难点,一个是聚合是在binary mask上做的而不是在model parameters上做的,第二是这些binary mask是异构的而不是拥有同样的网络结构。

FEDMASK DESIGN(FedMask实际的设计)

全文和MobiCom21那篇不一样的地方应该主要是这里

image-20221224152949878

  • each device learns a heterogeneous binary mask via the proposed one-shot pruning method ( 1 )
  • each device optimizes the binary mask with a structured sparsity regularization while freezing the parameters of local model ( 2 -a)
  • only the optimized binary masks are transmitted from the devices to the central server ( 2 -b)
  • The aggregation strategy is specifically designed such that only the elements that are overlapped across the binary masks of the devices are aggregated while keeping non-overlapping elements unchanged ( 3 -a)
  • the personalization of the binary masks is preserved and the updated binary masks will be sent back to each device ( 3 -b)
  • The above process ( 2 - 3 ) repeats until reaching a predefined number of communication rounds.
  • the binary mask will be elementwise applied to the frozen parameters to generate a personalized and structured sparse model ( 4 )
  • 和MobiCom21那篇不一样的地方是,那篇传送的是子网不是mask,这篇传送mask,且在本地只做mask的优化,其它的参数不变

Binary Mask Optimization(第一个挑战如何解决)

学习一个binary mask的同时冻结model parameters是FedMask技术的基础,下面论述。

不失一般性,以全连接为例(卷积层类似实现),bias项暂时忽略。

image-20221224160412547

  • 全连接层:$y=W\cdot x$, $y\in R^m$代表输出,$x\in R^{n}$代表输入,$W\in R^{m\times n}$代表权重矩阵。
  • 加上binary mask之后,也就是$m\in {\left{0,1\right}^{m\times n}}$和$W$的shape一样,带有Mask的全连接层:image-20221224160955140
  • 当前的优化算法(例如SGD)用在binary value上不可行,因此,引入一个real-valued mask即$m^r\in R^{m\times n}$来设计一个binary mask optimization,在feedforward step的时候,$m^r$被使用threshold binaried to m,如方程3所示:image-20221224161349004
  • 在back-propagation step,梯度m由方程4计算:image-20221224161714822
  • 尽管这样的策略能够实现$m^r$和$m$的优化,但可能会导致巨大的梯度尺度的变化,这会影响$m^r$的优化(见原文引文),为了减少这种gradient variance,加入sigmoid函数:image-20221224162103399
  • image-20221224162328426

实质上:就是加了个sigmoid函数,训练的时候没有特殊的处理,在推理的时候就用方程3操作一下。

One-Shot Pruning for Mask Initialization(第一个挑战如何解决)

经过Binary Mask Optimization我们知道了如何在binary matrix上进行反向传播,下面介绍如何进行剪枝,结构化剪枝而不是非结构化剪枝

有各种各样的方法去确定要剪去什么样的参数,例如threshold、kernel sparsity、entropy、filter importance等等,然而当前的方法都是为real-valued parameters设计的,因此不能够被直接用于剪枝binary mask。

在Binary Mask Optimization已经介绍了real-valued mask,因此一个naive的方法是直接将当前的剪枝的方法用到real-valued mask上。然而, 由方程4和7得知,real-valued的masks是直接由fixed weight来scale的,因此real-valued masks的值的大小不能作为pruning的依据,基于这样的观察,于是设计了one-shot pruning method,这个method基于real-valued的masks和the fixed weights,也就是$W\odot m^r$。

每个device在binary mask的top layers保存dense structure,在最后几层进行剪枝(也就是由分类的部分组成)。不用优化后的绝对值作为剪枝的依据,而使用$W\odot m^r$绝对值的变化作为剪枝的依据。定义剪枝率为$p_r$,剪枝的过程由两步组成:

  • (1)每个device一个epoch更新他们的real-valued masks
  • (2)通过对$W_{ij}\cdot m^r_{ij}$的值进行排序,选择最大的$p_r$比例的元素,其余的元素设为0并且冻结。然后在本地做local training。

作者认为第一个epoch更新的参数就是比较重要的参数,和W做一个点积可以知道哪些未知的参数比较重要,哪些位置参数不重要,把不重要的参数剪掉。

Local Binary Mask Optimization(第二个挑战如何解决)

在作为one-shot pruning之后,每个device就有了一个heterogeneous mask。训练的时候,冻结模型的参数,训练binary的mask,且只训练排名靠前的binary的mask。

模型的Loss:

为了提升设备上的计算的精度,在mask optimization的过程中使用结构化的稀疏正则化方法去学习binary masks with structured sparsity。目标是在卷积层获得channel-wise和filter-wise的sparsity,在全连接层获得row-wise和column-wise的sparsity。下面的公式和MobiCom21的基本一样:

image-20221224194914540

image-20221224194935059

方案的简单测试:

image-20221224195622039

  • LSTM的效果不好,为什么?

  • LSTM的公式如下所示:image-20221224195653878

  • 加了mask的LSTM的公式如下所示:image-20221224195738865

  • 原因:尽管masked based的CNN、MLP的performance drop都很小,但是mask based的CNN和MLP的表达能力确实受到了限制,这种表达能力的衰减在LSTM中表现得更厉害,因为这种表达能力的衰减会通过方程12中的nested mask structure展现出来,这种不断累加的表达能力的衰减会导致LSTM严重的performance drop。定义real-valued unit的表达能力为$\mu$,一个masked unit的表达衰减是$\epsilon<1$,在一次feedforward的过程中,mask based的LSTM会衰减到$(1-\epsilon)^3{\mu}^3$,正如图6中展示的那样,从$h_{t-1}$到$h_t$的时候,表达能力衰减的累加为$1-(1-\epsilon)^3$:image-20221224200539110

  • 为了减少这种衰减,就移除了$W_o$,$W_g$和$W_i$的binary mask,这样的话,表达能力的decay就只有$1-(1-\epsilon)^2$,只增加了25%的额外通信开销。

Aggregate Heterogeneous Binary Masks(第三个挑战如何解决)

对至少出现两次的elements做averging,其它的不变。

image-20221224201552518

一个例子:

image-20221224201607117

  • 这张图和MobiCom那张基本一样。

完整算法

image-20221224201656905

EVALUATION(评测)

和MobiCom21的那篇实验设置很类似。

image-20221224201820023

image-20221224201827422

image-20221224201835911

image-20221224201846183

image-20221224201856172

image-20221224201905717

Prev:
arXiv22-FEDNAS: FEDERATED DEEP LEARNING VIA NEURAL ARCHITECTURE SEARCH
Next:
MobiCom21-Hermes: An Efficient Federated Learning Framework for Heterogeneous Mobile
catalog
catalog