為諾亞方舟實驗室聯合雪梨大學發布論文《Kernel Based Progressive Distillation for Adder Neural Networks》,提出了針對加法神經網絡的蒸餾技術,ResNet-34和ResNet-50網絡在ImageNet上分別達到了68.8%和76.8%的準確率,效果與相同結構的CNN相比持平或超越,該論文已被NeurIPS2020接收。
論文連結:https://arxiv.org/pdf/2009.13044.pdf研究背景深度卷積神經網絡(CNN)被廣泛應用於諸多計算機視覺領域的實際任務中(例如,圖片分類、物體檢測、語義分割等)。然而,為了保證性能,神經網絡通常是過參數化的,因此會存在大量的冗餘參數。近期提出的加法神經網絡(ANN),通過將卷積操作中的距離度量函數替換為L1距離,極大減少了神經網絡中的乘法操作,從而減少了網絡運行所需的功耗和晶片面積。然而,ANN在準確率方面和同結構的CNN相比仍然有一定差距,在某種程度上限制了ANN在實際應用中對CNN的替換。為了提高ANN的性能,我們提出了一種基於核的漸進蒸餾方法。具體的,我們發現一個訓練好的ANN網絡其參數通常服從拉普拉斯分布,而一個訓練好的CNN網絡其參數通常服從高斯分布。因此,我們對網絡中間層的特徵圖輸出進行核變換後,使用距離度量函數估計教師網絡(CNN)和學生網絡(ANN)之間的損失。對於最後一層,我們使用傳統的KL散度估計兩個網絡之間的損失。同時,在訓練中我們使用隨機初始化的教師網絡,與學生網絡同時訓練,以減少兩個網絡之間參數分布的差異性。實驗表明,我們的算法得到的ANN能夠在CIFAR-10,CIFAR-100,ImageNet等標準圖片分類數據集上達到或超越同結構CNN的準確率。對網絡中間層特徵圖輸出進行核交換ANN本身精度不好的原因是原始ANN在反向傳播時,使用的是近似的梯度,導致目標函數無法向著最小的方向移動。傳統KD方法應用到ANN上效果不佳的原因,在於ANN的權重分布是拉普拉斯分布,而CNN的權重分布為高斯分布,因此分布不同導致無法直接對中間層的featuremap使用KD方法。本方法首先將核變換作用於教師網絡和學生網絡的中間層輸出,並使用1x1卷積對新的輸出進行配準。之後,結合最後一層的蒸餾損失與分類損失,得到整體的損失函數。具體的,給定ANN的第m層輸入和權重,以及CNN第m層的輸入和權重,他們的輸出分別為與。其中,和的定義分別為:之後對ANN和CNN的輸出分別進行拉普拉斯核變換和高斯核變換:得到核變換後的輸出,其中和是可學習的參數。之後,對核變換後的輸出分別過1x1的卷積層,得到最後的中間層輸出:其中,代表是1x1卷積操作。和為卷積操作中的參數。最後,我們對輸出y求MSE loss,使得ANN的中間層輸出學習CNN。即:除了上述中間層的loss之外,我們還希望ANN學習CNN最後一層的輸出,以及ANN關於目標任務的loss(這裡以分類任務舉例)。對於ANN的最後一層,在分類任務上輸出的是每一個類別的概率分布,因此希望它同時學習CNN的概率分布,以及ground-truth的概率分布(ground-truth為one-hot vector),因此構造的loss function為:將該loss與之前的loss結合,就得到最終的目標方程:
漸進式蒸餾算法
傳統的蒸餾方法使用固定的,訓練好的教師網絡來教學生網絡。這樣做會帶來問題。由於教師網絡和學生網絡處於不同的訓練階段,因此他們的分布會因為訓練階段的不同而不同,所以會導致KD方法效果不好。因此我們採用漸進式蒸餾方法,讓教師網絡和學生網絡共同學習,有助於KD方法得到好的結果。即目標函數變為:
其中b為當前的step。
實驗結果我們在CIFAR-10、CIFAR-100、ImageNet三個數據集上分別進行了實驗。下表是在CIFAR-10和CIFAR-100數據集上的結果,我們使用了VGG-small、ResNet-20與ResNet-32作為教師網絡,同結構的ANN作為學生網絡。可以看到,使用了本方法得到的ANN在分類準確率上相比原始的ANN有大幅度的提升,並且能夠超過同結構的CNN模型。表格中#Mul表示網絡中乘法操作的次數。#Add表示加法操作次數,#XNOR表示同或操作的次數。
下表展示了在ImageNet數據集上的結果,我們使用ResNet-18與ResNet-50網絡作為教師網絡,同結構的ANN作為學生網絡。結果顯示我們的方法得到的ANN在分類準確率上相比同結構CNN基本相同或能夠超越。
最後,我們展示了ResNet-20,ANN-20與通過本方法得到的PKKD ANN-20模型在CIFAR-10與CIFAR-100數據集上的訓練精度曲線與測試精度曲線。
圖中的實線表示訓練精度,虛線表示測試精度。在兩個數據集中,CNN的訓練和測試準確率都超過了原始的ANN模型。這是因為在訓練原始ANN時,反向傳播的梯度使用的是L2 norm來近似,因此梯度方向是不準確的。當使用本方法後,CNN的訓練過程可以指導ANN的訓練,因此可以得到更好的結果。同時,知識蒸餾方法能夠幫助學生網絡防止過擬合,這也是我們的方法有最低的訓練精度和最高的測試精度的原因。
— 完 —