向AI轉型的程式設計師都關注了這個號👇👇👇
機器學習AI算法工程 公眾號:datayx
最近用keras跑基於resnet50,inception3的一些遷移學習的實驗,遇到一些問題。通過查看github和博客發現是由於BN層導致的,國外已經有人總結並提了一個PR(雖然並沒有被merge到Keras官方庫中),並寫了一篇博客,也看到知乎有人翻譯了一遍:Keras的BN你真的凍結對了嗎
https://zhuanlan.zhihu.com/p/56225304
當保存模型後再加載模型去預測時發現與直接預測結果不一致也可能是BN層的問題。
總結:
顯式設置
不可否認的是,默認的Frozen的BN的行為在遷移學習中確實是有training這個坑存在的,個人認為fchollet的修複方法更簡單一點,並且這種方式達到的效果和使用預訓練網絡提取特徵,單獨訓練分類層達到的效果是一致的,當你真的想要凍結BN層的時候,這種方式更符合凍結的這個動機;但在測試時使用新數據集的移動均值和方差一定程度上也是一種domain adaption。
譯文:
雖然Keras節省了我們很多編碼時間,但Keras中BN層的默認行為非常怪異,坑了我(此處及後續的「我」均指原文作者)很多次。Keras的默認行為隨著時間發生過許多的變化,但仍然有很多問題以至於現在Keras的GitHub上還掛著幾個相關的issue。在這篇文章中,我會構建一個案例來說明為什麼Keras的BN層對遷移學習並不友好,並給出對Keras BN層的一個修復補丁,以及修復後的實驗效果。
這一節我會簡要介紹遷移學習和BN層,以及learning_phase的工作原理,Keras BN層在各個版本中的變化。如果你已經了解過這些知識,可以直接跳到第二節(譯者註:1.3和1.4跟這個問題還是比較相關的,不全是背景)。
1.1 遷移學習在深度學習中非常重要深度學習在過去廣受詬病,原因之一就是它需要太多的訓練數據了。解決這個限制的方法之一就是遷移學習。
假設你現在要訓練一個分類器來解決貓狗二分類問題,其實並不需要幾百萬張貓貓狗狗的圖片。你可以只對預訓練模型頂部的幾層卷積層進行微調。因為預訓練模型是用圖像數據訓練的,底層卷積層可以識別線條,邊緣或者其他有用的模式作為特徵使用,所以可以用預訓練模型的權重作為一個很好的初始化值,或者只對模型的一部分用自己數據進行訓練。
Keras包含多種預訓練模型,並且很容易Fine-tune,更多細節可以查閱Keras官方文檔。
1.2 Batch Normalization是個啥BN在2014年由Loffe和Szegedy提出,通過將前一層的輸出進行標準化解決梯度消失問題,並減小了訓練達到收斂所需的迭代次數,從而減少訓練時間,使得訓練更深的網絡成為可能。具體原理請看原論文,簡單來說,BN將每一層的輸入減去其在Batch中的均值,除以它的標準差,得到標準化的輸入,此外,BN也會為每個單元學習兩個因子來還原輸入。從下圖可以看到加了BN之後Loss下降更快,最後能達到的效果也更好。
1.3 Keras中的learning_phase是啥網絡中有些層在訓練時和推導時的行為是不同的。最重要的兩個例子就是BN和Dropout層。對BN層,訓練時我們需要用mini batch的均值和方差來縮放輸入。在推導時,我們用訓練時統計到的累計均值和方差對推導的mini batch進行縮放。
Keras用learning_phase機制來告訴模型當前的所處的模式。假如用戶沒有手工指定的話,使用fit()時,網絡默認將learning_phase設為1,表示訓練模式。在預測時,比如調用predict()和evaluate()方法或者在fit()的驗證步驟中,網絡將learning_phase設為0,表示測試模式。用戶可以靜態地,在model或tensor添加到一個graph中之前,將learning_phase設為某個值(雖然官方不推薦手動設置),設置後,learning_phase就不可以修改了。
1.4 不同版本中的Keras是如何實現BN的Keras中的BN訓練時統計當前Batch的均值和方差進行歸一化,並且使用移動平均法累計均值和方差,給測試集用於歸一化。
Keras中BN的行為變過幾次,但最重要的變更發生在2.1.3這個版本。2.1.3之前,當BN被凍結時(trainable=False),它仍然會更新mini batch的移動均值和方差,並用於測試,造成用戶的困擾(一副沒有凍結住的樣子)。
這種設計是錯誤的。考慮Conv1-Bn-Conv2-Conv3這樣的結構,如果BN層被凍結住了,應該無事發生才對。當Conv2處於凍結狀態時,如果我們部分更新了BN,那麼Conv2不能適應更新過的mini-batch的移動均值和方差,導致錯誤率上升。
在2.1.3及之後,當BN層被設為trainable=False時,Keras中不再更新mini batch的移動均值和方差,測試時使用的是預訓練模型中的移動均值和方差,從而達到凍結的效果, But is that enough? Not if you are using Transfer Learning.
2. 問題描述與解決方案我會介紹問題的根源以及解決方案(一個Keras補丁)的技術實現。同時我也會提供一些樣例來說明打補丁前後模型的準確率變化。
2.1 問題描述2.1.3版本後,當Keras中BN層凍結時,在訓練中會用mini batch的均值和方差統計值以執行歸一化。我認為更好的方式應該是使用訓練中得到的移動均值和方差(譯者註:這樣不就退回2.1.3之前的做法了)。原因和2.1.3的修復原因相同,由於凍結的BN的後續層沒有得到正確的訓練,使用mini batch的均值和方差統計值會導致較差的結果。
假設你沒有足夠的數據訓練一個視覺模型,你準備用一個預訓練Keras模型來Fine-tune。但你沒法保證新數據集在每一層的均值和方差與舊數據集的統計值的相似性。注意哦,在當前的版本中,不管你的BN有沒有凍結,訓練時都會用mini-batch的均值和方差統計值進行批歸一化,而在測試時你也會用移動均值方差進行歸一化。因此,如果你凍結了底層並微調頂層,頂層均值和方差會偏向新數據集,而推導時,底層會使用舊數據集的統計值進行歸一化,導致頂層接收到不同程度的歸一化的數據。
如上圖所示,假設我們從Conv K+1層開始微調模型,凍結左邊1到k層。訓練中,1到K層中的BN層會用訓練集的mini batch統計值來做歸一化,然而,由於每個BN的均值和方差與舊數據集不一定接近,在Relu處的丟棄的數據量與舊數據集會有很大區別,導致後續K+1層接收到的輸入和舊數據集的輸入範圍差別很大,後續K+1層的初始權重不能恰當處理這種輸入,導致精度下降。儘管網絡在訓練中可以通過對K+1層的權重調節來適應這種變化,但在測試模式下,Keras會用預訓練數據集的均值和方差,改變K+1層的輸入分布,導致較差的結果。
2.2 如何檢查你是否受到了這個問題的影響分別將learning_phase這個變量設置為1或0進行預測,如果結果有顯著的差別,說明你中招了。不過learning_phase這個參數通常不建議手工指定,learning_phase不會改變已經編譯後的模型的狀態,所以最好是新建一個乾淨的session,在定義graph中的變量之前指定learning_phase。
檢查AUC和ACC,如果acc只有50%但auc接近1(並且測試和訓練表現有明顯不同),很可能是BN迷之縮放的鍋。類似的,在回歸問題上你可以比較MSE和Spearman『s correlation來檢查。
2.3 如何修復如果BN在測試時真的鎖住了,這個問題就能真正解決。實現上,需要用trainable這個標籤來真正控制BN的行為,而不僅是用learning_phase來控制。具體實現在GitHub上。
主要是通過安裝補丁:作者提供了三個版本的補丁,安裝自己需要的版本就可以
用了這個補丁之後,BN凍結後,在訓練時它不會使用mini batch均值方差統計值進行歸一化,而會使用在訓練中學習到的統計值,避免歸一化的突變導致準確率的下降**。如果BN沒有凍結,它也會繼續使用訓練集中得到的統計值。**
原文:
By applying the above fix, when a BN layer is frozen it will no longer use the mini-batch statistics but instead use the ones learned during training. As a result, there will be no discrepancy between training and test modes which leads to increased accuracy. Obviously when the BN layer is not frozen, it will continue using the mini-batch statistics during training.
雖然這個補丁是最近才寫好的,但其中的思想已經在各種各樣的workaround中驗證過了。這些workaround包括:將模型分成兩部分,一部分凍結,一部分不凍結,凍結部分只過一遍提取特徵,訓練時只訓練不凍結的部分。為了增加說服力,我會給出一些例子來展示這個補丁的真實影響。
我會用一小塊數據來刻意過擬合模型,用相同的數據來訓練和驗證模型,那麼在訓練集和驗證集上都應該達到接近100%的準確率。
如果驗證的準確率低於訓練準確率,說明當前的BN實現在推導中是有問題的。
預處理在generator之外進行,因為keras2.1.5中有一個相關的bug,在2.1.6中修復了。
在推導時使用不同的learning_phase設置,如果兩種設置下準確率不同,說明確實中招了。
代碼如下:
輸出如下:
如上文所述,驗證集準確率確實要差一些。
訓練完成後,我們做了三個實驗,DYNAMIC LEARNING_PHASE是默認操作,由Keras內部機制動態決定learning_phase,static兩種是手工指定learning_phase,分為設為0和1.當learning_phase設為1時,驗證集的效果提升了,因為模型正是使用訓練集的均值和方差統計值來訓練的,而這些統計值與凍結的BN中存儲的值不同,凍結的BN中存儲的是預訓練數據集的均值和方差,不會在訓練中更新,會在測試中使用。這種BN的行為不一致性導致了推導時準確率下降。
加了補丁後的效果:
模型收斂得更快,改變learning_phase也不再影響模型的準確率了,因為現在BN都會使用訓練集的均值和方差進行歸一化。
2.5 這個修復在真實數據集上表現如何我們用Keras預訓練的ResNet50,在CIFAR10上開展實驗,只訓練分類層10個epoch,以及139層以後5個epoch。沒有用補丁的時候準確率為87.44%,用了之後準確率為92.36%,提升了5個點。
2.6 其他層是否也要做類似的修復呢?Dropout在訓練時和測試時的表現也不同,但Dropout是用來避免過擬合的,如果在訓練時也將其凍結在測試模式,Dropout就沒用了,所以Dropout被frozen時,我們還是讓它保持能夠隨機丟棄單元的現狀吧。
參考文獻:
https://zhuanlan.zhihu.com/p/56225304
http://blog.datumbox.com/the-batch-normalization-layer-of-keras-is-broken/
https://blog.csdn.net/wf592523813
閱讀過本文的人還看了以下文章:
TensorFlow 2.0深度學習案例實戰
基於40萬表格數據集TableBank,用MaskRCNN做表格檢測
《基於深度學習的自然語言處理》中/英PDF
Deep Learning 中文版初版-周志華團隊
【全套視頻課】最全的目標檢測算法系列講解,通俗易懂!
《美團機器學習實踐》_美團算法團隊.pdf
《深度學習入門:基於Python的理論與實現》高清中文PDF+源碼
特徵提取與圖像處理(第二版).pdf
python就業班學習視頻,從入門到實戰項目
2019最新《PyTorch自然語言處理》英、中文版PDF+源碼
《21個項目玩轉深度學習:基於TensorFlow的實踐詳解》完整版PDF+附書代碼
《深度學習之pytorch》pdf+附書源碼
PyTorch深度學習快速實戰入門《pytorch-handbook》
【下載】豆瓣評分8.1,《機器學習實戰:基於Scikit-Learn和TensorFlow》
《Python數據分析與挖掘實戰》PDF+完整源碼
汽車行業完整知識圖譜項目實戰視頻(全23課)
李沐大神開源《動手學深度學習》,加州伯克利深度學習(2019春)教材
筆記、代碼清晰易懂!李航《統計學習方法》最新資源全套!
《神經網絡與深度學習》最新2018版中英PDF+源碼
將機器學習模型部署為REST API
FashionAI服裝屬性標籤圖像識別Top1-5方案分享
重要開源!CNN-RNN-CTC 實現手寫漢字識別
yolo3 檢測出圖像中的不規則漢字
同樣是機器學習算法工程師,你的面試為什麼過不了?
前海徵信大數據算法:風險概率預測
【Keras】完整實現『交通標誌』分類、『票據』分類兩個項目,讓你掌握深度學習圖像分類
VGG16遷移學習,實現醫學圖像識別分類工程項目
特徵工程(一)
特徵工程(二) :文本數據的展開、過濾和分塊
特徵工程(三):特徵縮放,從詞袋到 TF-IDF
特徵工程(四): 類別特徵
特徵工程(五): PCA 降維
特徵工程(六): 非線性特徵提取和模型堆疊
特徵工程(七):圖像特徵提取和深度學習
如何利用全新的決策樹集成級聯結構gcForest做特徵工程並打分?
Machine Learning Yearning 中文翻譯稿
螞蟻金服2018秋招-算法工程師(共四面)通過
全球AI挑戰-場景分類的比賽源碼(多模型融合)
斯坦福CS230官方指南:CNN、RNN及使用技巧速查(列印收藏)
python+flask搭建CNN在線識別手寫中文網站
中科院Kaggle全球文本匹配競賽華人第1名團隊-深度學習與特徵工程
不斷更新資源
深度學習、機器學習、數據分析、python
搜索公眾號添加: datayx
機大數據技術與機器學習工程
搜索公眾號添加: datanlp
長按圖片,識別二維碼