從最優化的角度看待 Softmax 損失函數

2021-01-18 極市平臺

加入極市專業CV交流群,與6000+來自騰訊,華為,百度,北大,清華,中科院等名企名校視覺開發者互動交流!更有機會與李開復老師等大牛群內互動!

同時提供每月大咖直播分享、真實項目需求對接、乾貨資訊匯總,行業技術交流點擊文末「閱讀原文」立刻申請入群~


作者 | 王峰

來源 | https://zhuanlan.zhihu.com/p/45014864

本文經作者授權轉載,二次轉載請聯繫原作者。


Softmax交叉熵損失函數應該是目前最常用的分類損失函數了,在大部分文章中,Softmax交叉熵損失函數都是從概率角度來解釋的,本周二極市就推送了一篇Softmax相關文章:一文道盡softmax loss及其變種。


本文將嘗試從最優化的角度來推導出Softmax交叉熵損失函數,希望能夠啟發出更多的研究思路。


一般而言,最優化的問題通常需要構造一個目標函數,然後尋找能夠使目標函數取得最大/最小值的方法。目標函數往往難以優化,所以有了各種relax、smooth的方法,例如使用L1範數取代L0範數、使用sigmoid取代階躍函數等等。


那麼我們就要思考一個問題:使用神經網絡進行多分類(假設為 C 類)時的目標函數是什麼?神經網絡的作用是學習一個非線性函數 f(x) ,將輸入轉換成我們希望的輸出。這裡我們不考慮網絡結構,只考慮分類器(也就是損失函數)的話,最簡單的方法莫過於直接輸出一維的類別序號 。而這個方法的缺點顯而易見:我們事先並不知道這些類別之間的關係,而這樣做默認了相近的整數的類是相似的,為什麼第2類的左右分別是第1類和第3類,也許第2類跟第5類更為接近呢?


為了解決這個問題,可以將各個類別的輸出獨立開來,不再只輸出1個數而是輸出 C個分數(某些文章中叫作logit[1],但我感覺這個詞用得沒什麼道理,參見評論),每個類別佔據一個維度,這樣就沒有誰與誰更近的問題了。那麼如果讓一個樣本的真值標籤(ground-truth label)所對應的分數比其他分數更大,就可以通過比較 C個分數的大小來判斷樣本的類別了。這裡沿用我的論文[2]使用的名詞,稱真值標籤對應的類別分數為目標分數(target score),其他的叫非目標分數(non-target score)。


這樣我們就得到了一個優化目標:

輸出C個分數,使目標分數比非目標分數更大。


換成數學描述,設 為真值標籤的序號,那優化目標即為:


得到了目標函數之後,就要考慮優化問題了。我們可以給一個負的梯度,給其他所有一個正的梯度,經過梯度下降法,即可使 升高而 下降。為了控制整個神經網絡的幅度,不可以讓 無限地上升或下降,所以我們利用max函數,讓在  剛剛超過時就停止上升:


然而這樣做往往會使模型的泛化性能比較差,我們在訓練集上才剛剛讓 超過 ,那測試集很可能就不會超過。借鑑svm裡間隔的概念,我們添加一個參數,讓 比 大過一定的數值才停止:


這樣我們就推導出了hinge loss...唔,好像跑題了,我們本來不是要說Softmax的麼...不過既然跑題了就多說點,為什麼hinge loss在SVM時代大放異彩,但在神經網絡時代就不好用了呢?主要就是因為svm時代我們用的是二分類,通過使用一些小技巧比如1 vs 1、1 vs n等方式來做多分類問題。而如論文[3]這樣直接把hinge loss應用在多分類上的話,當類別數 特別大時,會有大量的非目標分數得到優化,這樣每次優化時的梯度幅度不等且非常巨大,極易梯度爆炸。


其實要解決這個梯度爆炸的問題也不難,我們把優化目標換一種說法:

輸出C個分數,使目標分數比最大的非目標分數更大。


跟之前相比,多了一個限制詞「最大的」,但其實我們的目標並沒有改變,「目標分數比最大的非目標分數更大」實際上等價於「目標分數比所有非目標分數更大」。這樣我們的損失函數就變成了:


在優化這個損失函數時,每次最多只會有一個+1的梯度和一個-1的梯度進入網絡,梯度幅度得到了限制。但這樣修改每次優化的分數過少,會使得網絡收斂極其緩慢,這時就又要祭出smooth大法了。那麼max函數的smooth版是什麼?有同學會脫口而出:softmax!恭喜你答錯了...


這裡出現了一個經典的歧義,softmax實際上並不是max函數的smooth版,而是one-hot向量(最大值為1,其他為0)的smooth版。其實從輸出上來看也很明顯,softmax的輸出是個向量,而max函數的輸出是一個數值,不可能直接用softmax來取代max。max函數真正的smooth版本是LogSumExp函數(LogSumExp:https://en.wikipedia.org/wiki/LogSumExp),對此感興趣的讀者還可以看看這個博客:尋求一個光滑的最大值函數(https://kexue.fm/archives/3290)。


使用LogSumExp函數取代max函數:


LogSumExp函數的導數恰好為softmax函數:


經過這一變換,給予非目標分數的1的梯度將會通過LogSumExp函數傳播給所有的非目標分數,各個非目標分數得到的梯度是通過softmax函數進行分配的,較大的非目標分數會得到更大的梯度使其更快地下降。這些非目標分數的梯度總和為1,目標分數得到的梯度為-1,總和為0,絕對值和為2,這樣我們就有效地限制住了梯度的總幅度。


LogSumExp函數值是大於等於max函數值的,而且等於取到的條件也是非常苛刻的(具體情況還是得看我的博士論文,這裡公式已經很多了,再寫就沒法看了),所以使用LogSumExp函數相當於變相地加了一定的 m。但這往往還是不夠的,我們可以選擇跟hinge loss一樣添加一個 ,那樣效果應該也會不錯,不過softmax交叉熵損失走的是另一條路:繼續smooth。


注意到ReLU函數 也有一個smooth版,即softplus函數 。使用softplus函數之後,即使  超過了LogSumExp函數,仍會得到一點點梯度讓  繼續上升,這樣其實也是變相地又增加了一點  ,使得泛化性能有了一定的保障。替換之後就可以得到:


這個就是大家所熟知的softmax交叉熵損失函數了。在經過兩步smooth化之後,我們將一個難以收斂的函數逐步改造成了softmax交叉熵損失函數,解決了原始的目標函數難以優化的問題。從這個推導過程中我們可以看出smooth化不僅可以讓優化更暢通,而且還變相地在類間引入了一定的間隔,從而提升了泛化性能。


至於如何利用這個推導來對損失函數進行修改和一些進一步的分析,未完待續...


[1] Pereyra G, Tucker G, Chorowski J, et al. Regularizing neural networks by penalizing confident output distributions[J]. arXiv preprint arXiv:1701.06548, 2017.

[2] Wang F, Cheng J, Liu W, et al. Additive margin softmax for face verification[J]. IEEE Signal Processing Letters, 2018, 25(7): 926-930.

[3] Tang Y. Deep learning using linear support vector machines[J]. arXiv preprint arXiv:1306.0239, 2013.





*延伸閱讀

每月大咖直播分享、真實項目需求對接、乾貨資訊匯總,行業技術交流點擊左下角「閱讀原文」立刻申請入群~


覺得有用麻煩給個好看啦~  

相關焦點

  • softmax 損失函數 & 參數更新詳解
    要點回歸softmax進階多分類 - 基礎理解softmax多分類實現圖解softmax 損失函數產生及理解對參數求偏導推導及更新要點回歸:邏輯回歸二分類用sigmoid變換成預測單個「概率」,損失函數為交叉熵,用梯度下降求解參數wbsoftmax多分類用softmax
  • 通過對抗損失函數來降低對抗損失函數的效用
    該文將進一步通過對抗損失函數來降低對抗損失函數的效用,並且使用優化ranknet-2模型實現對抗損失函數。同時將softmax函數映射到了循環網絡中,即將損失函數映射到了dnn整體和用戶的得分之間的差值之間的回歸梯度,增強網絡網絡的泛化能力。
  • 周末AI課堂 理解softmax函數 | 機器學習你會遇到的「坑」
    如果我們選擇添加sigmoid作為激活函數的隱層,那麼從整個神經網絡的角度來說,sigmoid函數在隱藏單元的作用只是提供非線性,在輸出單元的作用卻是作為分類器。很多人會把softmax當作sigmoid函數的推廣,但這樣的理解並不是自然的,比如面對多分類問題,softmax函數會輸出每個結果的概率,和為1;但sigmoid函數會對每一個數據輸出屬於該結果和不屬於該結果的概率,這個概率之和為1,但對於每個結果而言,概率之和並不為1,事實上,sigmoid函數的背後有著更為本質的原因,我們可以分別從廣義線性模型的角度來理解softmax
  • 神經網絡中的各種損失函數介紹
    不同的損失函數可用於不同的目標。在這篇文章中,我將帶你通過一些示例介紹一些非常常用的損失函數。這篇文章提到的一些參數細節都屬於tensorflow或者keras的實現細節。損失函數的簡要介紹損失函數有助於優化神經網絡的參數。
  • 深度學習中常見的損失函數
    而邏輯回歸的推導中,它假設樣本服從於伯努利分布(0-1分布),然後求得滿足該分布的似然函數,接著求取對數等(Log損失函數中採用log就是因為求解過中使用了似然函數,為了求解方便而添加log,因為添加log並不改變其單調性)。但邏輯回歸併沒有極大化似然函數,而是轉變為最小化負的似然函數,因此有了上式。已知邏輯函數(sigmoid函數)為:
  • RBF-Softmax:讓模型學到更具表達能力的類別表示
    因為傳統的softmax損失優化的是類內和類間的差異的最大化,也就是類內和類間的距離(logits)的差別的最大化,沒有辦法得到表示類別的向量表示來對類內距離進行正則化。之前的方法都是想辦法增加類內的內聚性,而忽視了不同的類別之間的關係。
  • PYNQ中實現SoftMax函數加速器
    孫齊偉本文引用地址:http://www.eepw.com.cn/article/201905/401026.htm  (西南交通大學 信息科學與技術學院,四川 成都 611756)  摘要:SoftMax函數通常在深度學習中作為激活函數使用,但其計算涉及自然指數和除法運算,傳統PC機上計算較慢,拖累了一個神經網絡的訓練。
  • 機器學習算法中的7個損失函數的詳細指南
    損失函數將決策映射到其相關成本決定走上坡的路徑將耗費我們的體力和時間。決定走下坡的路徑將使我們受益。因此,下坡的成本是更小的。在有監督的機器學習算法中,我們希望在學習過程中最小化每個訓練樣例的誤差。這是使用梯度下降等一些優化策略完成的。
  • Python機器學習算法中的7個損失函數的詳細指南
    這是使用梯度下降等一些優化策略完成的。而這個誤差來自損失函數。損失函數(Loss Function)和成本函數(Cost Function)之間有什麼區別?在此強調這一點,儘管成本函數和損失函數是同義詞並且可以互換使用,但它們是不同的。損失函數用於單個訓練樣本。它有時也稱為誤差函數(error function)。
  • 機器學習經典損失函數比較
    我們常常將最小化的函數稱為損失函數,它主要用于衡量模型的預測能力。在尋找最小值的過程中,我們最常用的方法是梯度下降法,這種方法很像從山頂下降到山谷最低點的過程。 雖然損失函數描述了模型的優劣為我們提供了優化的方向,但卻不存在一個放之四海皆準的損失函數。損失函數的選取依賴於參數的數量、局外點、機器學習算法、梯度下降的效率、導數求取的難易和預測的置信度等方面。
  • TensorFlow2.0(8):誤差計算——損失函數總結
    1 均方差損失函數:MSE 均方誤差(Mean Square Error),應該是最常用的誤差計算方法了,數學公式為:其中,0.4], dtype=float32)>loss_mse_2 = tf.reduce_mean(loss_mse_1)loss_mse_2<tf.Tensor: id=24, shape=(), dtype=float32, numpy=0.4>一般而言,均方誤差損失函數比較適用於回歸問題中
  • 機器學習之模型評估(損失函數)
    機器學習的模型評估,主要包括兩部分:損失函數是性能度量的一個部分,而損失函數又分很多種,因此單獨作為一篇介紹損失函數。機器學習中的所有算法都依賴於最小化或最大化一個函數,我們稱之為「目標函數」。最小化的函數稱為「損失函數」,損失函數衡量的是模型預測預期結果的能力。通常情況下,損失函數分為二類:回歸問題和分類問題。
  • 理解損失函數(代碼篇)機器學習你會遇到的「坑」
    全文共1950字,預計學習時長4分鐘在上一節,我們主要講解了替代損失(Surrogate loss)由來和性質,明白了機器學習中損失函數定義的本質,我們先對回歸任務總結一下常用的損失函數:均方誤差(MSE):
  • 「技術綜述」一文道盡softmax loss及其變種
    本文首發於知乎專欄《有三AI學院》,https://zhuanlan.zhihu.com/c_151876233今天來說說softmax loss以及它的變種1 softmax losssoftmax loss是我們最熟悉的loss之一了,分類任務中使用它,分割任務中依然使用它。
  • 楊植麟等人瞄準softmax瓶頸,新方法顧表達性和高效性
    但是,MoS 的內存和時間成本均高於 softmax,這使得它在計算資源有限的情況下實際應用性減弱。MoS →Mixtape為了降低 MoS 的計算成本,最近楊植麟等人提出了一種高效解決 softmax 瓶頸的新型輸出層 Mixtape。Mixtape 可作為額外層嵌入到任意現有網絡的交叉熵損失函數之前。
  • CMU楊植麟等人再次瞄準softmax瓶頸,Mixtape兼顧表達性和高效性
    在 10-30K 的詞彙量下,使用 Mixtape 的網絡僅比基於 softmax 的網絡慢 20%-34%,其困惑度和翻譯質量均優於 softmax。softmax 帶給我們的苦與樂大量神經網絡使用 softmax 作為標準輸出層,包括大部分神經語言模型。
  • 過度深入研究自然梯度優化
    在這個算法的最簡單的版本中,我們取一個標量,假設是0.1,然後乘以關於損失的梯度。我們的梯度,記住,實際上是一個矢量-損失的梯度相對於模型中的每個參數向量-損失的梯度模型中,當我們將它乘以一個標量時,我們用歐幾裡得的方法,將沿每個參數軸的更新乘以相同的固定量。在梯度下降的最基本版本中,我們在學習過程中使用相同的步長。但是,這真的有意義嗎?
  • 簡單的交叉熵損失函數,你真的懂了嗎?
    沒關係,接下來我將儘可能以最通俗的語言回答上面這幾個問題。1. 交叉熵損失函數的數學原理我們知道,在二分類問題模型:例如邏輯回歸「Logistic Regression」、神經網絡「Neural Network」等,真實樣本的標籤為 [0,1],分別表示負類和正類。