什麼是殘差神經網絡?
原則上,神經網絡的層數越多,應獲得越好的結果。一個更深層的網絡可以學到任何淺層的東西,甚至可能更多。如果對於給定的數據集,網絡無法通過添加更多的層來學習更多東西,那麼它就可以學習這些其他層的恆等映射(identity mappings)。這樣,它可以保留先前層中的信息,並且不會比較淺的層更糟糕。
但是,實際上情況並非如此。越深的網絡越難優化。隨著我們向網絡中添加層,我們在訓練過程中的難度也會增加;用於查找正確參數的優化算法也會變得越來越困難。隨著我們添加更多層,網絡將獲得更好的結果(直到某個時候為止)。然後,隨著我們繼續添加額外的層,準確性開始下降。
殘差網絡試圖通過添加所謂的skip connections來解決此問題。如前所述,更深層的網絡至少應該能夠學習恆等映射(identity mappings)。skip connections是這樣做的:它們從網絡中的一個點到另一點添加恆等映射,然後讓網絡僅學習額外的()。如果網絡沒有其他可以學習的東西,那麼它僅將()設為0。事實證明,對於網絡來說,學習一個更接近於0的映射比學習恆等映射更容易。
具有skip connection的塊稱為殘差塊,而殘差神經網絡(ResNet)只是這些塊的連接。
Keras Functional API簡介
可能您已經熟悉了Sequential類,它可以讓一個人很容易地構建一個神經網絡,只要把層一個接一個地堆疊起來,就像這樣:
但是,這種構建神經網絡的方式不足以滿足我們的需求。使用Sequential類,我們無法添加skip connections。Keras的Model類可與Functional API一起使用,以創建用於構建更複雜的網絡體系結構的層。
構造後,keras.layers.Input返回張量對象。Keras中的層對象也可以像函數一樣使用,以張量對象作為參數來調用它。返回的對象是張量,然後可以將其作為輸入傳遞到另一層,依此類推。
舉個例子:
這種語法的真正用途是在使用所謂的「 Merge」層時,通過該層可以合併更多輸入張量。這些層中的一些例子是:Add,Subtract,Multiply,Average。我們在構建剩餘塊時需要的是Add。
使用的Add示例:
ResNet的Python實現
接下來,我們將實現一個ResNet和其普通(無skip connections)副本,以進行比較。
我們將在此處構建的ResNet具有以下結構:
形狀為(32,32,3)的輸入1個Conv2D層,64個filters2、5、5、2殘差塊的filters分別為64、128、256和512池大小= 4的AveragePooling2D層Flatten層10個輸出節點的Dense層它共有30個conv+dense層。所有的核大小都是3x3。我們在conv層之後使用ReLU激活和BatchNormalization。
我們首先創建一個輔助函數,將張量作為輸入並為其添加relu和批歸一化:
然後,我們創建一個用於構造殘差塊的函數。
create_res_net()函數將所有內容組合在一起。這是完整的代碼:
普通網絡以類似的方式構建,但它沒有skip connections,我們也不使用residual_block()幫助函數;一切都在create_plain_net()中完成。
plain network的Python代碼如下:
訓練CIFAR-10並查看結果
CIFAR-10是一個包含10個類別的32x32 rgb圖像的機器學習數據集。它包含了50k的訓練圖像和10k的測試圖像。
以下是來自每個類別的10張隨機圖片樣本:
我們將在這個機器學習數據集上對ResNet和PlainNet進行20個epoch的訓練,然後比較結果。
ResNet和PlainNet在訓練時間上沒有顯著差異。我們得到的結果如下所示。
因此,通過在該機器學習數據集上使用ResNet ,我們將驗證準確性提高了1.59%。在更深層的網絡上,差異應該更大。