這裡是 「王喆的機器學習筆記」 的第二十六篇文章。這篇文章我們繼續討論機器學習模型的分布式訓練問題。
上篇文章對Spark MLlib的並行訓練方法做了詳細的介紹(分布式機器學習之——Spark MLlib並行訓練原理),如文章所說,Spark採取了簡單直觀的數據並行的方法解決模型並行訓練的問題,但由於Spark的並行梯度下降方法是同步阻斷式的,且模型參數需通過全局廣播的形式發送到各節點,因此Spark的並行梯度下降是相對低效的。
為了解決相應的問題,2014年分布式可擴展的Parameter Server被 沐神 李沐 提出,幾乎完美的解決了機器模型的分布式訓練問題,時至今日,parameter server不僅被直接應用在各大公司的機器學習平臺上,而且也被集成在TensorFlow,MXNet等主流的深度框架中,作為機器學習分布式訓練最重要的解決方案。
第一部分我們首先聚焦PS進行分布式訓練的基本原理。這裡以通用的機器學習問題為例。
上式是一個通用的帶正則化項的損失函數,其中n是樣本總數,l(x,y,w)是計算單個樣本的損失函數,x是特徵向量,y是樣本label,w是模型參數。那麼模型的訓練目標就是使損失函數F(w)最小。為了求解arg (min F(w)),往往使用梯度下降的方法,那麼Parameter Server的主要目的就是分布式並行進行梯度下降的計算完成參數的更新與最終收斂。需要注意的是,由於公式中正則化項的存在需要匯總所有模型參數才能夠正確計算,因此較難進行模型參數的並行訓練,因此Parameter Server採取了和Spark MLlib一樣的數據並行訓練產生局部梯度,再匯總梯度更新參數權重的並行化訓練方案。
具體來講,圖1以偽碼方式列出了Parameter Server並行梯度下降的主要步驟:
可以看到Parameter Server由server節點和worker節點組成,其主要功能分別如下:
server節點的主要功能是保存模型參數、接受worker節點計算出的局部梯度、匯總計算全局梯度,並更新模型參數
worker節點的主要功能是各保存部分訓練數據,從server節點拉取最新的模型參數,根據訓練數據計算局部梯度,上傳給server節點。
在物理架構上,PS其實是和spark的master-worker的架構基本一致的,具體如圖2
可以看到,PS分為兩大部分:server group和多個worker group,另外resource manager負責總體的資源分配調度。
server group內部包含多個server node,每個server node負責維護一部分參數,server manager負責維護和分配server資源;
每個worker group對應一個application(即一個模型訓練任務),worker group之間,以及worker group內部的worker node互相之間並不通信,worker node只與server通信。
結合PS的物理架構,PS的並行訓練整體示意圖如圖3:
圖3結合圖2描述的並行梯度下降方法的偽碼以及圖2的PS物理架構,清晰的描述了PS的並行梯度下降流程,其中最關鍵的兩個操作就是push和pull:
push:worker節點利用本節點上的訓練數據,計算好局部梯度,上傳給server節點;
pull:為了進行下一輪的梯度計算,worker節點從server節點拉取最新的模型參數到本地。
1.每個worker載入一部分訓練數據
2.worker節點從server節點pull最新的全部模型參數
3.worker節點利用本節點數據計算梯度
4.worker節點將梯度push到server節點
5.server節點匯總梯度更新模型
6.goto step2 直到迭代次數上限或模型收斂
在上篇文章介紹spark的並行梯度下降原理時,曾經提到spark並行梯度下降效率較低的原因就是每個節點都需要等待其他所有節點的梯度都計算完後,master節點匯總梯度,計算好新的模型參數後,才能開始下一輪的梯度計算,我們稱這種方式為「同步阻斷式」的並行梯度下降過程。
「同步阻斷式「的並行梯度下降雖然是嚴格意義上的一致性最強的梯度下降方法,因為其計算結果和串行計算的過程一直,但效率過低,各節點的waiting時間過長,有沒有辦法提高梯度下降的並行度呢?
PS採取的方法是用「異步非阻斷式」的梯度下降替代原來的同步式方法。圖4是一個worker節點多次迭代計算梯度的過程,可以看到節點在做第11次迭代(iter 11)計算時,第10次迭代後的push&pull過程並沒有結束,也就是說最新的模型權重參數還沒有被拉取到本地,該節點仍使用的是iter 10的權重參數計算的iter 11的梯度。這就是所謂的異步非阻斷式梯度下降方法,其他節點計算梯度的進度不會影響本節點的梯度計算。所有節點始終都在並行工作,不會被其他節點阻斷。
用下面轉載了兩個異步更新和同步更新的動畫,大家可以非常直觀的了解異步更新和同步更新的過程和區別。
當然,任何的技術方案都是取捨,異步梯度更新的方式雖然大幅加快了訓練速度,但帶來的是模型一致性的喪失,也就是說並行訓練的結果與原來的單點串行訓練的結果是不一致的,這樣的不一致會對模型收斂的速度造成一定影響。所以最終選取同步更新還是異步更新取決於不同模型對於一致性的敏感程度。這類似於一個模型超參數選取的問題,需要針對具體問題進行具體的驗證。
除此之外,在同步和異步之間,還可以通過一些「最大延遲」等參數來限制異步的程度。比如可以限定在三輪迭代之內,模型參數必須更新一次,那麼如果某worker節點計算了三輪梯度,該節點還未完成一次從server節點pull最新模型參數的過程,那麼該worker節點就必須停下等待pull操作的完成。這是同步和異步之間的折衷方法。
在PS論文的原文中也提供了異步和同步更新的效率對比,這裡可以作為參考(基於Sparse logistic regression模型訓練)。
SystemA和B都是同步更新梯度的系統,PS是異步更新的策略,可以看到PS的computing佔比遠高於同步更新策略
可以看到異步更新的PS的收斂速度也遠勝於同步更新的SystemA和B,這證明異步更新帶來的梯度不一致性的影響沒有想像中那麼大
導致Spark MLlib並行訓練效率低下的另一原因是每次迭代都需要master節點將模型權重參數的廣播發送到各worker節點。這導致兩個問題:
1.master節點作為一個瓶頸節點,受帶寬條件的制約,發送全部模型參數的效率不高;
2.同步地廣播發送所有權重參數,使系統整體的網絡負載非常大。
那麼PS是如何解決單點master效率低下的問題呢?從圖2的架構圖中可知,PS採用了server group內多server的架構,每個server主要負責一部分的模型參數。模型參數使用key value的形式,每個server負責一個key的range就可以了。
那麼另一個問題來了,每個server是如何決定自己負責哪部分key range呢?如果有新的server節點加入,又是如何在保證已有key range不發生大的變化的情況下加入新的節點呢?這兩個問題的答案涉及到一致性哈希(consistent hashing)的原理。
PS的server group中應用一致性哈希的原理大致有如下幾步:
1.將模型參數的key映射到一個環形的hash空間,比如有一個hash函數可以將任意key映射到0~(2^32)-1的hash空間內,我們只要讓(2^32)-1這個桶的下一個桶是0這個桶,那麼這個空間就變成了一個環形hash空間;
2.根據server節點的數量n,將環形hash空間等分成n*m個range,讓每個server間隔地分配m個hash range。這樣做的目的是保證一定的負載均衡性,避免hash值過於集中帶來的server負載不均;
3.在新加入一個server節點時,讓新加入的server節點找到hash環上的插入點,讓新的server負責插入點到下一個插入點之間的hash range,這樣做相當於把原來的某段hash range分成兩份,新的節點負責後半段,原來的節點負責前半段。這樣不會影響其他hash range的hash分配,自然不存在大量的rehash帶來的數據大混洗的問題。
4.刪除一個server節點時,移除該節點相關的插入點,讓臨近節點負責該節點的hash range。
PS server group中應用一致性哈希原理,其實非常有效的降低了原來單master節點帶來的瓶頸問題。比如現在某worker節點希望pull新的模型參數到本地,worker節點將發送不同的range pull到不同的server節點,server節點可以並行的發送自己負責的weight到worker節點。
此外,由於在處理梯度的過程中server節點之間也可以高效協同,某worker節點在計算好自己的梯度後,也只需要利用range push把梯度發送給一部分相關的server節點即可。當然,這一過程也與模型結構相關,需要跟模型本身的實現結合起來實現。總的來說,PS基於一致性哈希提供了range pull和range push的能力,讓模型並行訓練的實現更加靈活。
總結一下Parameter Server實現分布式機器學習模型訓練的要點:
1.用異步非阻斷式的分布式梯度下降策略替代同步阻斷式的梯度下降策略;
2.實現多server節點的架構,避免了單master節點帶來的帶寬瓶頸和內存瓶頸;
3.使用一致性哈希,range pull和range push等工程手段實現信息的最小傳遞,避免廣播操作帶來的全局性網絡阻塞和帶寬浪費。
但是大家要清楚的是,Parameter Server僅僅是一個管理並行訓練梯度的權重的平臺,並不涉及到具體的模型實現,因此PS往往是作為MXNet,TensorFlow的一個組件,要想具體實現一個機器學習模型,還需要依賴於通用的,綜合性的機器學習平臺。那麼下一篇文章,我們就來介紹一下以TensorFlow為代表的機器學習平臺的工作原理,特別是並行訓練的原理。
又到了大家能學到最多的問題時間,歡迎積極討論,分享業界經驗:
1.Parameter Server有哪些工程實現,大家在業界成功應用的Parameter Server的開源項目有哪些?
2.Parameter Server在離線訓練完成後,能否直接應用於線上inference,大家有沒有成功的經驗?
認為文章有價值的同學,歡迎關注我的 微信公眾號:王喆的機器學習筆記(wangzhenotes),跟蹤計算廣告、推薦系統等機器學習領域前沿。
想進一步交流的同學也可以通過公眾號加我的微信一同探討技術問題,謝謝