使用resnet, inception3進行fine-tune出現訓練集準確率很高但驗證集很低的問題

2021-02-19 機器學習AI算法工程




向AI轉型的程式設計師都關注了這個號👇👇👇

機器學習AI算法工程   公眾號:datayx

最近用keras跑基於resnet50,inception3的一些遷移學習的實驗,遇到一些問題。通過查看github和博客發現是由於BN層導致的,國外已經有人總結並提了一個PR(雖然並沒有被merge到Keras官方庫中),並寫了一篇博客,也看到知乎有人翻譯了一遍:Keras的BN你真的凍結對了嗎

https://zhuanlan.zhihu.com/p/56225304

當保存模型後再加載模型去預測時發現與直接預測結果不一致也可能是BN層的問題。

總結:

顯式設置

不可否認的是,默認的Frozen的BN的行為在遷移學習中確實是有training這個坑存在的,個人認為fchollet的修複方法更簡單一點,並且這種方式達到的效果和使用預訓練網絡提取特徵,單獨訓練分類層達到的效果是一致的,當你真的想要凍結BN層的時候,這種方式更符合凍結的這個動機;但在測試時使用新數據集的移動均值和方差一定程度上也是一種domain adaption。

譯文:
雖然Keras節省了我們很多編碼時間,但Keras中BN層的默認行為非常怪異,坑了我(此處及後續的「我」均指原文作者)很多次。Keras的默認行為隨著時間發生過許多的變化,但仍然有很多問題以至於現在Keras的GitHub上還掛著幾個相關的issue。在這篇文章中,我會構建一個案例來說明為什麼Keras的BN層對遷移學習並不友好,並給出對Keras BN層的一個修復補丁,以及修復後的實驗效果。

1. Introduction

這一節我會簡要介紹遷移學習和BN層,以及learning_phase的工作原理,Keras BN層在各個版本中的變化。如果你已經了解過這些知識,可以直接跳到第二節(譯者註:1.3和1.4跟這個問題還是比較相關的,不全是背景)。

1.1 遷移學習在深度學習中非常重要

深度學習在過去廣受詬病,原因之一就是它需要太多的訓練數據了。解決這個限制的方法之一就是遷移學習。

假設你現在要訓練一個分類器來解決貓狗二分類問題,其實並不需要幾百萬張貓貓狗狗的圖片。你可以只對預訓練模型頂部的幾層卷積層進行微調。因為預訓練模型是用圖像數據訓練的,底層卷積層可以識別線條,邊緣或者其他有用的模式作為特徵使用,所以可以用預訓練模型的權重作為一個很好的初始化值,或者只對模型的一部分用自己數據進行訓練。

Keras包含多種預訓練模型,並且很容易Fine-tune,更多細節可以查閱Keras官方文檔。

1.2 Batch Normalization是個啥

BN在2014年由Loffe和Szegedy提出,通過將前一層的輸出進行標準化解決梯度消失問題,並減小了訓練達到收斂所需的迭代次數,從而減少訓練時間,使得訓練更深的網絡成為可能。具體原理請看原論文,簡單來說,BN將每一層的輸入減去其在Batch中的均值,除以它的標準差,得到標準化的輸入,此外,BN也會為每個單元學習兩個因子來還原輸入。從下圖可以看到加了BN之後Loss下降更快,最後能達到的效果也更好。

1.3 Keras中的learning_phase是啥

網絡中有些層在訓練時和推導時的行為是不同的。最重要的兩個例子就是BN和Dropout層。對BN層,訓練時我們需要用mini batch的均值和方差來縮放輸入。在推導時,我們用訓練時統計到的累計均值和方差對推導的mini batch進行縮放。

Keras用learning_phase機制來告訴模型當前的所處的模式。假如用戶沒有手工指定的話,使用fit()時,網絡默認將learning_phase設為1,表示訓練模式。在預測時,比如調用predict()和evaluate()方法或者在fit()的驗證步驟中,網絡將learning_phase設為0,表示測試模式。用戶可以靜態地,在model或tensor添加到一個graph中之前,將learning_phase設為某個值(雖然官方不推薦手動設置),設置後,learning_phase就不可以修改了。

1.4 不同版本中的Keras是如何實現BN的

Keras中的BN訓練時統計當前Batch的均值和方差進行歸一化,並且使用移動平均法累計均值和方差,給測試集用於歸一化。

Keras中BN的行為變過幾次,但最重要的變更發生在2.1.3這個版本。2.1.3之前,當BN被凍結時(trainable=False),它仍然會更新mini batch的移動均值和方差,並用於測試,造成用戶的困擾(一副沒有凍結住的樣子)。

這種設計是錯誤的。考慮Conv1-Bn-Conv2-Conv3這樣的結構,如果BN層被凍結住了,應該無事發生才對。當Conv2處於凍結狀態時,如果我們部分更新了BN,那麼Conv2不能適應更新過的mini-batch的移動均值和方差,導致錯誤率上升。

在2.1.3及之後,當BN層被設為trainable=False時,Keras中不再更新mini batch的移動均值和方差,測試時使用的是預訓練模型中的移動均值和方差,從而達到凍結的效果, But is that enough? Not if you are using Transfer Learning.

2. 問題描述與解決方案

我會介紹問題的根源以及解決方案(一個Keras補丁)的技術實現。同時我也會提供一些樣例來說明打補丁前後模型的準確率變化。

2.1 問題描述

2.1.3版本後,當Keras中BN層凍結時,在訓練中會用mini batch的均值和方差統計值以執行歸一化。我認為更好的方式應該是使用訓練中得到的移動均值和方差(譯者註:這樣不就退回2.1.3之前的做法了)。原因和2.1.3的修復原因相同,由於凍結的BN的後續層沒有得到正確的訓練,使用mini batch的均值和方差統計值會導致較差的結果。

假設你沒有足夠的數據訓練一個視覺模型,你準備用一個預訓練Keras模型來Fine-tune。但你沒法保證新數據集在每一層的均值和方差與舊數據集的統計值的相似性。注意哦,在當前的版本中,不管你的BN有沒有凍結,訓練時都會用mini-batch的均值和方差統計值進行批歸一化,而在測試時你也會用移動均值方差進行歸一化。因此,如果你凍結了底層並微調頂層,頂層均值和方差會偏向新數據集,而推導時,底層會使用舊數據集的統計值進行歸一化,導致頂層接收到不同程度的歸一化的數據。

如上圖所示,假設我們從Conv K+1層開始微調模型,凍結左邊1到k層。訓練中,1到K層中的BN層會用訓練集的mini batch統計值來做歸一化,然而,由於每個BN的均值和方差與舊數據集不一定接近,在Relu處的丟棄的數據量與舊數據集會有很大區別,導致後續K+1層接收到的輸入和舊數據集的輸入範圍差別很大,後續K+1層的初始權重不能恰當處理這種輸入,導致精度下降。儘管網絡在訓練中可以通過對K+1層的權重調節來適應這種變化,但在測試模式下,Keras會用預訓練數據集的均值和方差,改變K+1層的輸入分布,導致較差的結果。

2.2 如何檢查你是否受到了這個問題的影響

分別將learning_phase這個變量設置為1或0進行預測,如果結果有顯著的差別,說明你中招了。不過learning_phase這個參數通常不建議手工指定,learning_phase不會改變已經編譯後的模型的狀態,所以最好是新建一個乾淨的session,在定義graph中的變量之前指定learning_phase。

檢查AUC和ACC,如果acc只有50%但auc接近1(並且測試和訓練表現有明顯不同),很可能是BN迷之縮放的鍋。類似的,在回歸問題上你可以比較MSE和Spearman『s correlation來檢查。

2.3 如何修復

如果BN在測試時真的鎖住了,這個問題就能真正解決。實現上,需要用trainable這個標籤來真正控制BN的行為,而不僅是用learning_phase來控制。具體實現在GitHub上。

主要是通過安裝補丁:作者提供了三個版本的補丁,安裝自己需要的版本就可以

用了這個補丁之後,BN凍結後,在訓練時它不會使用mini batch均值方差統計值進行歸一化,而會使用在訓練中學習到的統計值,避免歸一化的突變導致準確率的下降**。如果BN沒有凍結,它也會繼續使用訓練集中得到的統計值。**

原文:
By applying the above fix, when a BN layer is frozen it will no longer use the mini-batch statistics but instead use the ones learned during training. As a result, there will be no discrepancy between training and test modes which leads to increased accuracy. Obviously when the BN layer is not frozen, it will continue using the mini-batch statistics during training.

2.4 評估這個補丁的影響

雖然這個補丁是最近才寫好的,但其中的思想已經在各種各樣的workaround中驗證過了。這些workaround包括:將模型分成兩部分,一部分凍結,一部分不凍結,凍結部分只過一遍提取特徵,訓練時只訓練不凍結的部分。為了增加說服力,我會給出一些例子來展示這個補丁的真實影響。

我會用一小塊數據來刻意過擬合模型,用相同的數據來訓練和驗證模型,那麼在訓練集和驗證集上都應該達到接近100%的準確率。

如果驗證的準確率低於訓練準確率,說明當前的BN實現在推導中是有問題的。

預處理在generator之外進行,因為keras2.1.5中有一個相關的bug,在2.1.6中修復了。

在推導時使用不同的learning_phase設置,如果兩種設置下準確率不同,說明確實中招了。

代碼如下:

輸出如下:

如上文所述,驗證集準確率確實要差一些。

訓練完成後,我們做了三個實驗,DYNAMIC LEARNING_PHASE是默認操作,由Keras內部機制動態決定learning_phase,static兩種是手工指定learning_phase,分為設為0和1.當learning_phase設為1時,驗證集的效果提升了,因為模型正是使用訓練集的均值和方差統計值來訓練的,而這些統計值與凍結的BN中存儲的值不同,凍結的BN中存儲的是預訓練數據集的均值和方差,不會在訓練中更新,會在測試中使用。這種BN的行為不一致性導致了推導時準確率下降。

加了補丁後的效果:

模型收斂得更快,改變learning_phase也不再影響模型的準確率了,因為現在BN都會使用訓練集的均值和方差進行歸一化。

2.5 這個修復在真實數據集上表現如何

我們用Keras預訓練的ResNet50,在CIFAR10上開展實驗,只訓練分類層10個epoch,以及139層以後5個epoch。沒有用補丁的時候準確率為87.44%,用了之後準確率為92.36%,提升了5個點。

2.6 其他層是否也要做類似的修復呢?

Dropout在訓練時和測試時的表現也不同,但Dropout是用來避免過擬合的,如果在訓練時也將其凍結在測試模式,Dropout就沒用了,所以Dropout被frozen時,我們還是讓它保持能夠隨機丟棄單元的現狀吧。

參考文獻:
https://zhuanlan.zhihu.com/p/56225304
http://blog.datumbox.com/the-batch-normalization-layer-of-keras-is-broken/

https://blog.csdn.net/wf592523813

閱讀過本文的人還看了以下文章:

TensorFlow 2.0深度學習案例實戰

基於40萬表格數據集TableBank,用MaskRCNN做表格檢測

《基於深度學習的自然語言處理》中/英PDF

Deep Learning 中文版初版-周志華團隊

【全套視頻課】最全的目標檢測算法系列講解,通俗易懂!

《美團機器學習實踐》_美團算法團隊.pdf

《深度學習入門:基於Python的理論與實現》高清中文PDF+源碼

特徵提取與圖像處理(第二版).pdf

python就業班學習視頻,從入門到實戰項目

2019最新《PyTorch自然語言處理》英、中文版PDF+源碼

《21個項目玩轉深度學習:基於TensorFlow的實踐詳解》完整版PDF+附書代碼

《深度學習之pytorch》pdf+附書源碼

PyTorch深度學習快速實戰入門《pytorch-handbook》

【下載】豆瓣評分8.1,《機器學習實戰:基於Scikit-Learn和TensorFlow》

《Python數據分析與挖掘實戰》PDF+完整源碼

汽車行業完整知識圖譜項目實戰視頻(全23課)

李沐大神開源《動手學深度學習》,加州伯克利深度學習(2019春)教材

筆記、代碼清晰易懂!李航《統計學習方法》最新資源全套!

《神經網絡與深度學習》最新2018版中英PDF+源碼

將機器學習模型部署為REST API

FashionAI服裝屬性標籤圖像識別Top1-5方案分享

重要開源!CNN-RNN-CTC 實現手寫漢字識別

yolo3 檢測出圖像中的不規則漢字

同樣是機器學習算法工程師,你的面試為什麼過不了?

前海徵信大數據算法:風險概率預測

【Keras】完整實現『交通標誌』分類、『票據』分類兩個項目,讓你掌握深度學習圖像分類

VGG16遷移學習,實現醫學圖像識別分類工程項目

特徵工程(一)

特徵工程(二) :文本數據的展開、過濾和分塊

特徵工程(三):特徵縮放,從詞袋到 TF-IDF

特徵工程(四): 類別特徵

特徵工程(五): PCA 降維

特徵工程(六): 非線性特徵提取和模型堆疊

特徵工程(七):圖像特徵提取和深度學習

如何利用全新的決策樹集成級聯結構gcForest做特徵工程並打分?

Machine Learning Yearning 中文翻譯稿

螞蟻金服2018秋招-算法工程師(共四面)通過

全球AI挑戰-場景分類的比賽源碼(多模型融合)

斯坦福CS230官方指南:CNN、RNN及使用技巧速查(列印收藏)

python+flask搭建CNN在線識別手寫中文網站

中科院Kaggle全球文本匹配競賽華人第1名團隊-深度學習與特徵工程

不斷更新資源

深度學習、機器學習、數據分析、python

 搜索公眾號添加: datayx  

機大數據技術與機器學習工程

 搜索公眾號添加: datanlp

長按圖片,識別二維碼

相關焦點

  • [PyTorch 學習筆記] 7.2 模型 Finetune
    本章代碼:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/finetune_resnet18.py這篇文章主要介紹了模型的 Finetune。
  • 【乾貨】PyTorch實例:用ResNet進行交通標誌分類
    section=gtsrb&subsection=dataset▌實驗方法我嘗試使用在ImageNet數據集上預訓練的ResNet34卷積神經網絡來進行遷移學習。我在fast.ai最新版本的「深入學習編碼器」課程中學到了解決計算機視覺問題的方法。
  • 乾貨| BERT fine-tune 終極實踐教程
    google此次開源的BERT是通過tensorflow高級API—— tf.estimator進行封裝(wrapper)的。因此對於不同數據集的適配,只需要修改代碼中的processor部分,就能進行代碼的訓練、交叉驗證和測試。以下是奇點機智技術團隊對BERT在中文數據集上的fine tune終極實踐教程。
  • 機器學習 - 訓練集、驗證集、測試集
    訓練集在訓練模型時可能會出現過擬合問題(過擬合指模型可以很好的匹配訓練數據但預測其它數據時效果不好),所以一般需要在訓練集中再分出一部分作為驗證集,用於評估模型的訓練效果和調整模型的超參數(hyperparameter)。如下圖,展示了一套數據集的一般分配方式:訓練集用於構建模型。
  • 「fine-tune」別理解成「好的調子」!
    大家好,今天我們分享一個非常有用且地道的表達——fine-tune, 這個短語的含義不是指「好的調子」,其正確的含義是:fine-tune 對…進行微調,調整She spent hours fine-tuning
  • PyTorch 學習筆記(五):Finetune和各層定製學習率
    而在實際應用中,我們通常採用一個已經訓練模型的模型的權值參數作為我們模型的初始化參數,也稱之為Finetune,更寬泛的稱之為遷移學習。遷移學習中的Finetune技術,本質上就是讓我們新構建的模型,擁有一個較好的權值初始值。
  • 使用FastAI 和即時頻率變換進行音頻分類
    目前深度學習模型能處理許多不同類型的問題,對於一些教程或框架用圖像分類舉例是一種流行的做法,常常作為類似「hello, world」 那樣的引例。FastAI 是一個構建在 PyTorch 之上的高級庫,用這個庫進行圖像分類非常容易,其中有一個僅用四行代碼就可訓練精準模型的例子。
  • 【總251期-量化講堂097】聊聊機器學習中訓練集,測試集,以及cross validation的概念
    訓練集和測試集的概念非常容易理解,但是在實做的過程中,我們往往不能將全部數據用於訓練模型的參數,否則我們將沒有數據集對該模型進行驗證,從而評估我們模型的預測效果。為了解決這一問題,我們經常採用如下的方法進行處理:1.
  • DL經典論文系列(二) AlexNet、VGG、GoogLeNet/Inception、ResNet
    網絡的深度是提高網絡性能的關鍵,但是隨著網絡深度的加深,梯度消失問題逐漸明顯,甚至出現退化現象。所謂退化就是深層網絡的性能竟然趕不上較淺的網絡。本文提出殘差結構,當輸入為x時其學習到的特徵記為H(x),現在希望可以學習到殘差F(x)= H(x) - x,因為殘差學習相比原始特徵直接學習更容易。當殘差為0時,此時僅僅做了恆等映射,至少網絡性能不會下降。
  • xmnlp v0.3.0 更新:使用及原理介紹
    這次更新幾乎把所有的模型都升級了,使用了近年來比較熱門的技術。以下是 ChangeLog:•重構詞法分析:一個深度模型統一分詞、詞性標註、命名體識別;調用接口不變;去除繁體的支持;去除自定義詞典的支持。•重構文本糾錯:基於人民日報語料構造負樣本,使用 RoBERTa finetune 識別錯誤詞,並使用 Mask Language Model 的特性進行正確詞的預測。
  • 使用sklearn-svm進行多分類
    實際上,svm經過合適的設計也可以運用於多分類問題,sklearn中的svm模塊封裝了libsvm和liblinear,本節我們利用它進行多分類。SVM回顧SVM算法最初是為二值分類問題設計的,當處理多類問題時,就需要構造合適的多類分類器。
  • 基礎入門,怎樣用PaddlePaddle優雅地寫VGG與ResNet
    一般來說,圖像分類通過手工提取特徵或特徵學習方法對整個圖像進行全部描述,然後使用分類器判別物體類別,因此如何提取圖像的特徵至關重要。基於深度學習的圖像分類方法,可以通過有監督或無監督的方式學習層次化的特徵描述,從而取代了手工設計或選擇圖像特徵的工作。
  • 重讀經典:完全解析特徵學習大殺器ResNet
    ,ResNet的出現能夠解決這個問題。首先,深度網絡優化是比較困難的,比如會出現梯度爆炸/梯度消失等問題。不過,這個問題已經被normalized initialization和batch normalization等措施解決得差不多了。但是新問題又來了:deeper network收斂是收斂了,卻出現了效果上的degradationdeeper network準確率飽和後,很快就退化了為什麼會這樣呢?