在本教程中,我將實現聯邦學習(FL)的構建塊,並從頭開始對MNIST機器學習數據集進行訓練。
介紹
FL架構的基本形式包括一個位於中心的管理員或伺服器,負責協調訓練活動。客戶端主要是邊緣設備,可以達到數百萬的數量。這些設備在每次訓練迭代中至少與伺服器通信兩次。首先,它們各自從伺服器接收當前全局模型的權重,在各自的本地數據上對其進行訓練,以生成更新後的參數,然後將這些參數上傳到伺服器進行匯總。這種通信循環一直持續到達到預先設定的epochs數或準確度條件為止。在聯邦平均算法中,匯總僅僅意味著平均操作。
現在我們已經了解了FL是什麼以及它如何工作,讓我們繼續在Tensorflow中從頭開始構建一個FL,並在Kaggle的MNIST數據集上(kaggle.com/scolianni/mnistasjpg)對其進行訓練。
導入所需的Python庫
讀取和預處理MNIST數據集
我正在使用MNIST數據集的jpeg版本。它由42000個數字圖像組成,每個類保存在單獨的文件夾中。我將使用一下Python代碼片段將數據加載到內存中,並保留10%的數據,以便稍後測試經過訓練的全局模型。
在第9行,從磁碟讀取每個圖像作為灰度,然後將其flattened。flattening步驟是導入,因為稍後我們將使用MLP網絡架構。然後我們分割圖像的路徑,在第11行獲得它的類ID。我們還在第13行上的將圖像縮放到[0,1],以減弱像素亮度變化的影響。
訓練和測試拆分
在這個代碼片段中執行了兩個步驟。我們應用了在前面的代碼塊中定義的load函數來獲得圖像列表(現在是numpy數組)和標籤列表。之後,我們使用來自sklearn的LabelBinarizer對象對標籤進行one-hot編碼。不再把數字1的標籤作為數字1,它現在的形式是[0,1,0,0,0,0,0,0,0,0,0]。使用這種標記樣式,我將能夠使用Tensorflow中的交叉熵損失作為我們模型的損失函數。或者,我可以保持標籤不變,而使用稀疏分類熵損失。最後,我使用sklearn的train_test_split將數據拆分成比例為9:1的train/test。
聯邦成員(客戶端)
在FL的實際實現中,每個聯邦成員將獨立擁有自己的數據。請記住,FL的目標是將模型傳遞到數據。我將訓練集分成10個碎片,每個客戶一個。我寫了一個函數create_clients來實現這一目標。
在第13行,我使用前綴字符串創建了一個客戶端名稱列表。第16–21行將數據和標籤壓縮,將所得的元組列表隨機化並分片為所需數量的客戶端(num_clients)。在第26行,返回了一個字典,其中包含作為鍵的每個客戶端名稱和作為值的它們的數據共享。現在讓我們將這個函數應用到我們的訓練數據集。
clients = create_clients(X_train, y_train, num_clients=10, initial='client')
批處理客戶端和測試數據
接下來是將每個客戶端數據處理為tensorflow數據集並進行批處理。為了簡化這個步驟並避免重複,我將這個過程封裝到一個名為batch_data的小函數中。
每個客戶端數據集都是以create_clients中的數據/標籤元組列表的形式出現的。在上面的第9行,為了與TFDS API兼容,我將元組拆分為單獨的數據和標籤列表。在應用此函數時,我還將處理測試集,並將其保留到以後使用。
創建模型
在介紹部分我沒有提到的一件事是FL最適合參數化學習——所有類型的神經網絡。諸如KNN或類似的機器學習技術僅存儲訓練數據,而學習可能無法從FL中受益。我正在創建一個2層MLP作為分類任務的模型。我將使用Keras API創建此文件。
要構建新模型,將調用build方法。它需要輸入數據的形狀和類的數量作為參數。使用MNIST,shape參數將是28*28*1 = 784,而類的數量將是10。此時,我還將為模型編譯定義一個優化器、損失函數和度量。
SGD是默認優化器。損失函數為categorical_crossentropy,度量為accuracy。但是,在衰變參數中看起來有些奇怪。comms_round是什麼?它只是我想要運行的全局epochs(aggregations)數量。
模型匯總(加權平均)
到目前為止,根據深度學習管道,我所做的一切幾乎都是標準的(除了數據分區和客戶端創建)。我所使用的數據是水平分區的,因此我將簡單地進行組件級參數平均,並根據每個參與客戶端貢獻的數據點的比例進行加權。這是我用的聯邦平均方程
在右側,我們根據單個客戶端持有的每個數據點上記錄的損失值來估計權重參數。在左邊,我們縮放了客戶的參數並對結果求和。
下面我將這個過程封裝為三個簡單的函數。
(1)weight_scalling_factor 計算客戶的本地訓練數據在所有客戶持有的總體訓練數據中所佔的比例。首先,我們估計客戶的批次大小,然後使用它來計算自己的數據點數量。然後,我們獲得了第6行上的總體全局訓練數據大小。最後,我們在第9行以分數的形式計算了比例因子。這當然不可能是實際應用程式中的方法。任何客戶都不能訪問合併的訓練數據。在這種情況下,在每個本地訓練步驟之後用新參數更新伺服器時,每個客戶機都應該指出它們所持有的數據點的數量。
(2)scale_model_weights根據(1)中計算的比例因子的值來縮放每個局部模型的權重
(3)sum_scaled_weights將所有客戶的比例權重加在一起。
聯邦模型訓練
訓練邏輯有兩個主循環,外循環用於全局迭代,內循環用於迭代每個客戶端的本地訓練。
首先構建全局模型,輸入形狀為(784),數字類為10。然後我進入了外循環。首先獲得全局模型的初始化權值。第15行和第16行隨機化了客戶端字典順序。然後開始遍歷客戶端。
對於每個客戶端,我初始化一個新的模型對象,編譯它,並將它的初始化權重設置為全局模型的當前參數。然後對局部模型(客戶端)進行一個epoch的訓練。在訓練之後,新的權重將被縮放並附加到scaled_local_weight_list中。
回到第41行的外循環,我獲取了所有縮放後的局部訓練權重的總和,並將全局模型更新為這個新的匯總。這樣就結束了完整的全局訓練epoch。按照前面聲明的comms_round參數的規定,我運行了100個全局訓練循環。
最後在第48行,我使用預留的測試集,在每一輪通信結束後,對訓練好的全局模型進行測試,代碼如下:
結果
測試結果有10個客戶端,每個客戶端運行1個本地epoch,並進行100次全局通信。
與SGD比較
FL模型測試結果很好,經過100輪通信後,測試準確率達到了96.5%。但它與在相同數據集上訓練的標準SGD模型相比如何呢?我將在聯邦訓練數據上訓練一個模型(而不是像在FL中那樣訓練10個模型)。為此,我將使用分區之前的預處理訓練數據來訓練完全相同的2層MLP模型。
為了確保一個公平,我將保留用於FL訓練的每個超級參數,但batch size除外。不是使用32,我們的SGD模型的batch size將是320。
在100個epoch之後,SGD模型的測試精度達到了94.5%。在這個數據集上,FL的表現比SGD要好一點,不過這種結果在現實世界中是不可能出現的。客戶端持有的真實聯邦數據大多是非獨立同分布IID)的數據。例如,如果我們根據訓練數據集構造客戶機碎片,使每個客戶機的shad由單個類組成,比如client_1隻有數字1的圖像,client_2隻有數字2的圖像,等等,我們就可以模擬這個場景。如果採用這種非IID安排,我們的FL模型的測試準確率可能會下降6%。