SysML 2019 | Priority-based Parameter Propagation for Distributed DNN Training

本文提出了一种基于优先级的参数同步方法,通过该方法可以更高效地将计算和通信进行重叠,从而加快深度神经网络(Deep Neural Network, DNN)的训练速度。

在DNN模型的单机训练中,每次迭代时,数据从第一层开始逐层传递,进行前向传播计算loss;然后从最后一层开始计算梯度,将梯度反向传播回第一层,在反向传播的同时使用优化算法(如SGD)更新参数,这样就完成了一次模型更新,紧接着开始下一次迭代。只有当本次迭代中第一层的参数更新完成后,下一次迭代才能开始,两次迭代间的关系如图1所示。

DNN模型的分布式训练相较于单机训练多了参数同步的操作。传统的参数计算和同步过程如图4(a)所示,假设有一个三层的神经网络,在本次迭代的反向传播阶段,我们首先计算第一层的梯度,当第三层的梯度计算完成后,就可以开始相关的参数同步操作,同时计算第二层的梯度。当第二层的梯度计算完成后,需要等待第三层的梯度同步完成才能进行该层的梯度同步,因此第二层的梯度需要等待,与此同时计算第一层的梯度。第一层梯度计算完成的同时第三层梯度同步完成,这时会首先同步第二层的梯度,完成后再同步第一层的梯度。第一层梯度同步完毕之后才能开始下一次迭代。

前面提到,本次迭代第一层反向传播结束后才能开始下一次迭代的前向传播,因此,第一层反向传播结束的越早,下一次迭代开始的时间越早,整体的训练时间就会缩短。因此,本文提出了基于优先级的参数同步方法,计算和同步过程如图4(b)所示。在参数同步时,一旦第一层的梯度计算完成,那么就会立即进行该层的参数同步,目的就是为了尽早的完成第一层的反向传播。通过对比可以看出,本例中基于优先级的参数同步方法将两次迭代之间的延迟减少了2。

参数同步的过程包括以下三部分:

  • worker向server发送梯度
  • server更新参数
  • server向worker传回参数

现存的深度学习框架对参数同步过程做了一些并行化,即worker向server发送当前层的梯度时,server在进行上一层梯度的参数更新。但是这种优化只适合DNN各层参数数量相近的情况,如果DNN某一层的参数非常多(例如VGG-19的某个全连接层包含了整个模型71.5%的参数),那么这种粒度的并行化效果就会打折扣。如图6(a)所示,假设三层神经网络中第二层参数比较多,那么该层的参数同步时间就会很长,从而拖慢整体的训练速度。为此,本文提出了参数切片(parameter slicing),具体思路就是把较大的参数矩阵分割成很多小的参数矩阵,使参数同步的每一步消耗的时间都类似,从而更好地利用流水线进行并行化。在本例中,通过参数切片,整体的参数同步时间从10减少到了7。

P3: Design and Implementation

P3的核心由两部分组成:参数切片和基于优先级的调度。P3会把每一层的参数分割成很多切片,这些参数切片会独立地更新和同步。完成分割后,P3会基于DNN中层的顺序给每个切片一个优先级。一般来说,优先级从前向后依次降低。

KVStore

KVstore通过一种启发式的方式确保全局参数均匀地分布在KVServer中。如果某一层的参数大小小于一个固定的阈值,那么这层参数会被分配到一个随机的KVServer上;如果大于固定阈值,那么这层的参数会被均匀地分割到所有的KVServer上,阈值的默认值是$10^6$。在训练过程中,一旦某一层的反向传播完成,MXNet就通过KVWorker向KVServer发出该层的参数同步请求。KVWorker将梯度矩阵序列化并向相应的KVServer发出推送请求,KVServer会在收到所有KVWorker推送的梯度后进行梯度聚合并更新全局模型。全局的参数更新完毕后,KVServer会通知KVWorker。KVWorker会在收到通知后立即向KVServer发送拉取参数的请求。随后KVServer会将最新的模型参数发送给KVWorker,KVWorker利用这些参数来更新本地模型,进行下一次迭代。MXNet在推送某一层梯度的同时会反向传播计算上一层的梯度,从而将参数同步与反向传播重叠。

P3

P3在实现上把KVWorker和KVServer替换成P3Worker和P3Server。P3Worker通过预定义的阈值实现参数切片,与KVWorker不同的是,这个阈值定义了参数切片的最大粒度。分割后的参数矩阵通过Round-Robin算法发送给P3Server。P3使用基于生产者——消费者模式的优先级队列实现基于优先级的梯度调度。在参数切片之后,P3Worker的生产者部分为各个切片分配优先级,并将它们一次性推送到优先级队列中。P3Worker的消费者线程轮询队列中最高优先级的切片,并通过网络将参数切片和优先级发送到P3Server。消费者线程的网络调用是阻塞式的,因此它会基于网络延迟自动调整轮询优先级队列的速率。生产者——消费者模型确保网络不会承受来自P3Worker的突发流量,并且worker端的反向传播也不会受到影响。

P3在server端也实现了生产者——消费者模式,以减少网络内的延迟。P3Server接收到的数据包会根据P3Worker指定的key进入到优先级队列中。然后,服务器消费者线程将从优先队列轮询并以与KVServer中相同的方式处理数据包。

除了上面提到的之外,P3还移除了全局参数更新完成后KVServer对KVWorker的显式通知。参数更新完成后,P3Server会立即向P3Worker广播所有的参数。通过移除这种显式的通知,P3可以提高网络带宽的利用率。