加入極市專業CV交流群,與6000+來自騰訊,華為,百度,北大,清華,中科院等名企名校視覺開發者互動交流!更有機會與李開復老師等大牛群內互動!
同時提供每月大咖直播分享、真實項目需求對接、乾貨資訊匯總,行業技術交流。點擊文末「閱讀原文」立刻申請入群~
作者 | 王峰
來源 | https://zhuanlan.zhihu.com/p/45014864
本文經作者授權轉載,二次轉載請聯繫原作者。
Softmax交叉熵損失函數應該是目前最常用的分類損失函數了,在大部分文章中,Softmax交叉熵損失函數都是從概率角度來解釋的,本周二極市就推送了一篇Softmax相關文章:一文道盡softmax loss及其變種。
本文將嘗試從最優化的角度來推導出Softmax交叉熵損失函數,希望能夠啟發出更多的研究思路。
一般而言,最優化的問題通常需要構造一個目標函數,然後尋找能夠使目標函數取得最大/最小值的方法。目標函數往往難以優化,所以有了各種relax、smooth的方法,例如使用L1範數取代L0範數、使用sigmoid取代階躍函數等等。
那麼我們就要思考一個問題:使用神經網絡進行多分類(假設為 C 類)時的目標函數是什麼?神經網絡的作用是學習一個非線性函數 f(x) ,將輸入轉換成我們希望的輸出。這裡我們不考慮網絡結構,只考慮分類器(也就是損失函數)的話,最簡單的方法莫過於直接輸出一維的類別序號 。而這個方法的缺點顯而易見:我們事先並不知道這些類別之間的關係,而這樣做默認了相近的整數的類是相似的,為什麼第2類的左右分別是第1類和第3類,也許第2類跟第5類更為接近呢?
為了解決這個問題,可以將各個類別的輸出獨立開來,不再只輸出1個數而是輸出 C個分數(某些文章中叫作logit[1],但我感覺這個詞用得沒什麼道理,參見評論),每個類別佔據一個維度,這樣就沒有誰與誰更近的問題了。那麼如果讓一個樣本的真值標籤(ground-truth label)所對應的分數比其他分數更大,就可以通過比較 C個分數的大小來判斷樣本的類別了。這裡沿用我的論文[2]使用的名詞,稱真值標籤對應的類別分數為目標分數(target score),其他的叫非目標分數(non-target score)。
這樣我們就得到了一個優化目標:
輸出C個分數,使目標分數比非目標分數更大。換成數學描述,設 為真值標籤的序號,那優化目標即為:
。
得到了目標函數之後,就要考慮優化問題了。我們可以給一個負的梯度,給其他所有一個正的梯度,經過梯度下降法,即可使 升高而 下降。為了控制整個神經網絡的幅度,不可以讓 無限地上升或下降,所以我們利用max函數,讓在 剛剛超過時就停止上升:
。
然而這樣做往往會使模型的泛化性能比較差,我們在訓練集上才剛剛讓 超過 ,那測試集很可能就不會超過。借鑑svm裡間隔的概念,我們添加一個參數,讓 比 大過一定的數值才停止:
。
這樣我們就推導出了hinge loss...唔,好像跑題了,我們本來不是要說Softmax的麼...不過既然跑題了就多說點,為什麼hinge loss在SVM時代大放異彩,但在神經網絡時代就不好用了呢?主要就是因為svm時代我們用的是二分類,通過使用一些小技巧比如1 vs 1、1 vs n等方式來做多分類問題。而如論文[3]這樣直接把hinge loss應用在多分類上的話,當類別數 特別大時,會有大量的非目標分數得到優化,這樣每次優化時的梯度幅度不等且非常巨大,極易梯度爆炸。
其實要解決這個梯度爆炸的問題也不難,我們把優化目標換一種說法:
輸出C個分數,使目標分數比最大的非目標分數更大。跟之前相比,多了一個限制詞「最大的」,但其實我們的目標並沒有改變,「目標分數比最大的非目標分數更大」實際上等價於「目標分數比所有非目標分數更大」。這樣我們的損失函數就變成了:
。
在優化這個損失函數時,每次最多只會有一個+1的梯度和一個-1的梯度進入網絡,梯度幅度得到了限制。但這樣修改每次優化的分數過少,會使得網絡收斂極其緩慢,這時就又要祭出smooth大法了。那麼max函數的smooth版是什麼?有同學會脫口而出:softmax!恭喜你答錯了...
這裡出現了一個經典的歧義,softmax實際上並不是max函數的smooth版,而是one-hot向量(最大值為1,其他為0)的smooth版。其實從輸出上來看也很明顯,softmax的輸出是個向量,而max函數的輸出是一個數值,不可能直接用softmax來取代max。max函數真正的smooth版本是LogSumExp函數(LogSumExp:https://en.wikipedia.org/wiki/LogSumExp),對此感興趣的讀者還可以看看這個博客:尋求一個光滑的最大值函數(https://kexue.fm/archives/3290)。
使用LogSumExp函數取代max函數:
,
LogSumExp函數的導數恰好為softmax函數:
。
經過這一變換,給予非目標分數的1的梯度將會通過LogSumExp函數傳播給所有的非目標分數,各個非目標分數得到的梯度是通過softmax函數進行分配的,較大的非目標分數會得到更大的梯度使其更快地下降。這些非目標分數的梯度總和為1,目標分數得到的梯度為-1,總和為0,絕對值和為2,這樣我們就有效地限制住了梯度的總幅度。
LogSumExp函數值是大於等於max函數值的,而且等於取到的條件也是非常苛刻的(具體情況還是得看我的博士論文,這裡公式已經很多了,再寫就沒法看了),所以使用LogSumExp函數相當於變相地加了一定的 m。但這往往還是不夠的,我們可以選擇跟hinge loss一樣添加一個 ,那樣效果應該也會不錯,不過softmax交叉熵損失走的是另一條路:繼續smooth。
注意到ReLU函數 也有一個smooth版,即softplus函數 。使用softplus函數之後,即使 超過了LogSumExp函數,仍會得到一點點梯度讓 繼續上升,這樣其實也是變相地又增加了一點 ,使得泛化性能有了一定的保障。替換之後就可以得到:
這個就是大家所熟知的softmax交叉熵損失函數了。在經過兩步smooth化之後,我們將一個難以收斂的函數逐步改造成了softmax交叉熵損失函數,解決了原始的目標函數難以優化的問題。從這個推導過程中我們可以看出smooth化不僅可以讓優化更暢通,而且還變相地在類間引入了一定的間隔,從而提升了泛化性能。
至於如何利用這個推導來對損失函數進行修改和一些進一步的分析,未完待續...
[1] Pereyra G, Tucker G, Chorowski J, et al. Regularizing neural networks by penalizing confident output distributions[J]. arXiv preprint arXiv:1701.06548, 2017.
[2] Wang F, Cheng J, Liu W, et al. Additive margin softmax for face verification[J]. IEEE Signal Processing Letters, 2018, 25(7): 926-930.
[3] Tang Y. Deep learning using linear support vector machines[J]. arXiv preprint arXiv:1306.0239, 2013.
*延伸閱讀
覺得有用麻煩給個好看啦~