理解Batch Normalization

2020-12-06 騰訊網

  作者&編輯:李中梁

  引言

  上文

  提過不要在神經網絡中使用dropout層,用BN層可以獲得更好的模型。經典論文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出了Batch Normalization 批標準化的概念,towardsdatascience上一文《Intuit and Implement: Batch Normalization》詳細解釋了BN的原理,並通過在Cifar 100上的實驗證明了其有效性。全文編譯如下。

  神經網絡在訓練過程中的問題

  神經網絡在訓練過程中往往會遇到一些問題:

  問題1: 隨著網絡訓練,淺層的權重發生變化,導致深層的輸入變化很大。因此每層必須根據每批輸入的不同分布重新調整其權重。這減緩了模型訓練。如果我們可以使層的輸入分布更相似,那麼網絡可以專注於學習類別之間的差異。

  不同批次分布的另一個影響是梯度彌散。梯度彌散是一個大問題,特別是對於S形激活函數(sigmoid)。如果g(x)表示sigmoid激活函數,隨著 |x| 增加,g』(x)趨於零。

  問題2:當輸入分布變化時,神經元輸出也會變化。這導致神經元輸出偶爾波動到S形函數的可飽和區域。在那裡,神經元既不能更新自己的權重,也不能將梯度傳遞迴先前的層。那麼我們該如何保證神經元輸出到不飽和區域?

  如果我們可以將神經元輸出限制在零附近的區域,我們可以確保每個層在反向傳播期間都會返回一個有效的梯度。這將減少訓練時間和提高準確率。

  使用批量規範(BN)作為解決方案

  批量標準化減輕了不同層輸入對訓練的影響。通過歸一化神經元的輸出,激活函數將僅接收接近零的輸入。這確保了梯度的有效回傳,解決了第二個問題。

  批量歸一化將層輸出轉換為單位高斯分布。當這些輸出通過激活功能饋送時,層激活也將變得更加正常分布。

  由於上一層的輸出是下一層的輸入,因此層輸入的變化在不同批次輸間的變化將顯著減少。通過減少層的輸入的變化分布,我們解決了第一個問題。

  數學解釋

  通過批量歸一化,我們為每個激活函數尋找均值為0,方差為1的分布作為輸入。在訓練期間,我們將激活輸入x減去批次均值μ以實現零中心分布。

  接下來,我們取x並將其除以批處理方差和一個小數字,以防止除以零σ+ε。這可確保所有激活輸入分布方差為1。

  最後,我們將得到的x進行線性變換。這樣儘管在反向傳播期間網絡發生了變化,但仍能確保保持這種標準化效果。

  在測試模型時,我們不使用當前批次均值或方差,因為這會破壞模型。相反,我們計算訓練群體的移動均值和方差估計值。這些估計值是訓練期間計算的所有批次平均值和方差的平均值。

  批標準化的好處

  1.有助於減少具有可飽和非線性函數的網絡中的消失梯度問題。

  通過批標準化,我們確保任何激活函數的輸入不會進入飽和區域。批量歸一化將這些輸入的分布轉換為0-1高斯分布。

  2.正則化模型

  Ioffe和Svegeddy提出了這一主張,但沒有就此問題進行深入探討。也許這是歸一化層輸入的結果?

  3.允許更高的學習率

  通過防止在訓練期間消失梯度的問題,我們可以設置更高的學習率。批量標準化還降低了對參數標度的依賴性。大的學習速率可以增加層參數的規模,這導致梯度在反向傳播期間被回傳時放大。

  使用Keras實現Batch Normalization

  導入需要的庫

  數據加載和預處理

  我們使用了Cifar 100數據集,因為它具有一定的挑戰性,並且不會訓練太久。唯一的預處理是zero-centering和image variation generator。

  在Keras中構建模型

  我們的架構將包括堆疊的3x3卷積。每個網絡中有5個卷積塊。最後一層是一個完全連接的層,有100個節點與softmax激活。

  我們將構建4個不同的卷積網絡,每個網絡都具有sigmoid或ReLU激活以及批量標準化或不標準化。我們將比較每個網絡的驗證損失。

  模型訓練

  沒有BN的Sigmoid

  可以看到訓練無法收斂。有100個類,這個模型永遠不會比隨機猜測(10%準確率)獲得更好的性能。

  具有BN的Sigmoid

  與沒有批量標準化不同,該模型在培訓期間開始實施。這可能是批量標準化減輕消失梯度的結果。

  沒有BN的ReLU

  在沒有批量規範的情況下實施ReLU使得訓練初始效果不錯,然後模型收斂到非最優的局部最小值。

  具有BN的ReLU

  與sigmoid模型一樣,批量標準化提高了該網絡的訓練能力。

  架構比較

  我們清楚地看到了批量標準化的優勢。沒有批量標準化的ReLU和Sigmoid模型都無法保持訓練性能提升。這可能是漸變消失的結果。具有批量標準化的體系結構訓練得更快,並且比沒有批量標準化的體系結構表現更好。

  結論

  批量標準化減少了訓練時間並提高了神經網絡的穩定性。此效果適用於sigmoid和ReLU激活功能。

  代碼:https://github.com/harrisonjansma/Research-Computer-Vision/tree/master/07-28-18-Implementing-Batch-Norm

相關焦點

  • 神經網絡算法Batch Normalization的分析與展望 | 大牛講堂
    從這裡的分析,也可以看出來,BN網絡相對而言更加需要全局的徹底的隨機shuffle,如果沒有徹底的shuffle,幾條樣本總是出現在同一個mini batch中,那麼很可能有些樣本對於在經過標準化之後,在很多神經元的輸出總是零,從而難以將信號傳遞到頂層。這可能就是為什麼[1]提到增加了shuffle之後,結果在驗證集上有所變好。
  • 神經網絡算法BatchNormalization的分析與展望|大牛講堂
    同時這也可能為什麼當使用了BN+隨機shuffle之後dropout的作用在下降的一個原因,因為數據隨機的組合,利用mini batch統計量標準化之後,對於特定樣本的特定神經元可能隨機地為0,那麼和dropout的機制有類似的地方。但是如果使用很大的mini batch又會如何呢?
  • 批歸一化Batch Normalization的原理及算法
    batch中每個樣本的差異性越大,這種弊端就越嚴重。BN首先是把所有的samples的統計分布標準化,降低了batch內不同樣本的差異性,然後又允許batch內的各個samples有各自的統計分布。所以,BN的優點自然也就是允許網絡使用較大的學習速率進行訓練,加快網絡的訓練速度(減少epoch次數),提升效果。省去參數選擇的問題。
  • 如何理解深度學習分布式訓練中的large batch size與learning rate...
    雷鋒網 AI科技評論按,本文源自譚旭在知乎問題【如何理解深度學習分布式訓練中的large batch size與learning rate的關係?】下的回答,雷鋒網 AI科技評論獲其授權轉載。問題詳情:在深度學習進行分布式訓練時,常常採用同步數據並行的方式,也就是採用大的batch size進行訓練,但large batch一般較於小的baseline的batch size性能更差,請問如何理解調試learning rate能使large batch達到small batch同樣的收斂精度和速度?
  • Batch、Mini-batch和隨機梯度下降的區別和Python示例
    Mini-batch梯度下降想像一下,將您的數據集分成幾個batches。這樣,它就不必等到算法遍歷整個數據集後才更新權重和偏差,而是在每個所謂的Mini-batch結束時進行更新。這使得我們能夠快速將成本函數移至全局最小值,並在每個epoch中多次更新權重和偏差。最常見的Mini-batch大小是16、32、64、128、256和512。
  • batchplot批量列印怎麼用?Batchplot功能特色
    Batchplot(CAD批量列印工具)安裝步驟1、在華軍軟體園(其他也可)下載batchplot軟體包,解壓縮雙擊運行它的.exe文件。之後就打開了它的安裝嚮導,點擊下一步;2、閱讀軟體相關協議,點擊同意再點擊下一步;3、接著就來到軟體信息界面,點擊下一步。
  • 透徹分析批歸一化Batch Normalization強大作用
    BN對網絡的中間層執行白化本文只關注BN為什麼工作的這麼好,如果要詳細理解BN詳細算法,請閱讀另一篇文章《批歸一化Batch Normalization的原理及算法》,本文從以下六個方面來闡述批歸一化為什麼有如此好的效力:(1)激活函數
  • 一文詳解深度學習中的Normalization:BN/LN/WN
    的均值和方差,因而稱為 Batch Normalization。 ,但規範化的參數是一個 mini-batch 的一階統計量和二階統計量。這就要求 每一個 mini-batch 的統計量是整體統計量的近似估計,或者說每一個 mini-batch 彼此之間,以及和整體數據,都應該是近似同分布的。
  • 理解卷積神經網絡中的輸入與輸出形狀(Keras實現)
    即使我們從理論上理解了卷積神經網絡,在實際進行將數據擬合到網絡時,很多人仍然對其網絡的輸入和輸出形狀(shape)感到困惑。本文章將幫助你理解卷積神經網絡的輸入和輸出形狀。讓我們看看一個例子。CNN的輸入數據如下圖所示。