原文:TensorFlow四種Cross Entropy算法實現和應用,作者授權CSDN轉載。
➤交叉熵介紹
交叉熵(Cross Entropy)是Loss函數的一種(也稱為損失函數或代價函數),用於描述模型預測值與真實值的差距大小,常見的Loss函數就是均方平方差(Mean Squared Error),定義如下。
平方差很好理解,預測值與真實值直接相減,為了避免得到負數取絕對值或者平方,再做平均就是均方平方差。注意這裡預測值需要經過sigmoid激活函數,得到取值範圍在0到1之間的預測值。
平方差可以表達預測值與真實值的差異,但在分類問題種效果並不如交叉熵好,原因可以參考這篇博文 。交叉熵的定義如下,截圖來自https://hit-scir.gitbooks.io/neural-networks-and-deep-learning-zh_cn/content/chap3/c3s1.html 。
上面的文章也介紹了交叉熵可以作為Loss函數的原因,首先是交叉熵得到的值一定是正數,其次是預測結果越準確值越小,注意這裡用於計算的「a」也是經過sigmoid激活的,取值範圍在0到1。如果label是1,預測值也是1的話,前面一項y * ln(a)就是1 * ln(1)等於0,後一項(1 - y) * ln(1 - a)也就是0 * ln(0)等於0,Loss函數為0,反之Loss函數為無限大非常符合我們對Loss函數的定義。
這裡多次強調sigmoid激活函數,是因為在多目標或者多分類的問題下有些函數是不可用的,而TensorFlow本身也提供了多種交叉熵算法的實現。
➤TensorFlow的交叉熵函數
TensorFlow針對分類問題,實現了四個交叉熵函數,分別是
tf.nn.sigmoid_cross_entropy_with_logits、tf.nn.softmax_cross_entropy_with_logits、tf.nn.sparse_softmax_cross_entropy_with_logits和tf.nn.weighted_cross_entropy_with_logits,詳細內容參考API文檔 https://www.tensorflow.org/versions/master/api_docs/python/nn.html#sparse_softmax_cross_entropy_with_logits
sigmoid_cross_entropy_with_logits詳解
我們先看sigmoid_cross_entropy_with_logits,為什麼呢,因為它的實現和前面的交叉熵算法定義是一樣的,也是TensorFlow最早實現的交叉熵算法。這個函數的輸入是logits和targets,logits就是神經網絡模型中的 W * X矩陣,注意不需要經過sigmoid,而targets的shape和logits相同,就是正確的label值,例如這個模型一次要判斷100張圖是否包含10種動物,這兩個輸入的shape都是[100, 10]。注釋中還提到這10個分類之間是獨立的、不要求是互斥,這種問題我們成為多目標,例如判斷圖片中是否包含10種動物,label值可以包含多個1或0個1,還有一種問題是多分類問題,例如我們對年齡特徵分為5段,只允許5個值有且只有1個值為1,這種問題可以直接用這個函數嗎?答案是不可以,我們先來看看sigmoid_cross_entropy_with_logits的代碼實現吧。
可以看到這就是標準的Cross Entropy算法實現,對W * X得到的值進行sigmoid激活,保證取值在0到1之間,然後放在交叉熵的函數中計算Loss。對於二分類問題這樣做沒問題,但對於前面提到的多分類,例如年輕取值範圍在0~4,目標值也在0~4,這裡如果經過sigmoid後預測值就限制在0到1之間,而且公式中的1 - z就會出現負數,仔細想一下0到4之間還不存在線性關係,如果直接把label值帶入計算肯定會有非常大的誤差。因此對於多分類問題是不能直接代入的,那其實我們可以靈活變通,把5個年齡段的預測用onehot encoding變成5維的label,訓練時當做5個不同的目標來訓練即可,但不保證只有一個為1,對於這類問題TensorFlow又提供了基於Softmax的交叉熵函數。
softmax_cross_entropy_with_logits詳解
Softmax本身的算法很簡單,就是把所有值用e的n次方計算出來,求和後算每個值佔的比率,保證總和為1,一般我們可以認為Softmax出來的就是confidence也就是概率,算法實現如下。
softmax_cross_entropy_with_logits和sigmoid_cross_entropy_with_logits很不一樣,輸入是類似的logits和lables的shape一樣,但這裡要求分類的結果是互斥的,保證只有一個欄位有值,例如CIFAR-10中圖片只能分一類而不像前面判斷是否包含多類動物。想一下問什麼會有這樣的限制?在函數頭的注釋中我們看到,這個函數傳入的logits是unscaled的,既不做sigmoid也不做softmax,因為函數實現會在內部更高效得使用softmax,對於任意的輸入經過softmax都會變成和為1的概率預測值,這個值就可以代入變形的Cross Entroy算法- y * ln(a) - (1 - y) * ln(1 - a)算法中,得到有意義的Loss值了。如果是多目標問題,經過softmax就不會得到多個和為1的概率,而且label有多個1也無法計算交叉熵,因此這個函數隻適合單目標的二分類或者多分類問題,TensorFlow函數定義如下。
再補充一點,對於多分類問題,例如我們的年齡分為5類,並且人工編碼為0、1、2、3、4,因為輸出值是5維的特徵,因此我們需要人工做onehot encoding分別編碼為00001、00010、00100、01000、10000,才可以作為這個函數的輸入。理論上我們不做onehot encoding也可以,做成和為1的概率分布也可以,但需要保證是和為1,和不為1的實際含義不明確,TensorFlow的C++代碼實現計劃檢查這些參數,可以提前提醒用戶避免誤用。
sparse_softmax_cross_entropy_with_logits詳解
sparse_softmax_cross_entropy_with_logits是softmax_cross_entropy_with_logits的易用版本,除了輸入參數不同,作用和算法實現都是一樣的。前面提到softmax_cross_entropy_with_logits的輸入必須是類似onehot encoding的多維特徵,但CIFAR-10、ImageNet和大部分分類場景都只有一個分類目標,label值都是從0編碼的整數,每次轉成onehot encoding比較麻煩,有沒有更好的方法呢?答案就是用sparse_softmax_cross_entropy_with_logits,它的第一個參數logits和前面一樣,shape是[batch_size, num_classes],而第二個參數labels以前也必須是[batch_size, num_classes]否則無法做Cross Entropy,這個函數改為限制更強的[batch_size],而值必須是從0開始編碼的int32或int64,而且值範圍是[0, num_class),如果我們從1開始編碼或者步長大於1,會導致某些label值超過這個範圍,代碼會直接報錯退出。這也很好理解,TensorFlow通過這樣的限制才能知道用戶傳入的3、6或者9對應是哪個class,最後可以在內部高效實現類似的onehot encoding,這只是簡化用戶的輸入而已,如果用戶已經做了onehot encoding那可以直接使用不帶「sparse」的softmax_cross_entropy_with_logits函數。
weighted_sigmoid_cross_entropy_with_logits詳解
weighted_sigmoid_cross_entropy_with_logits是sigmoid_cross_entropy_with_logits的拓展版,輸入參數和實現和後者差不多,可以多支持一個pos_weight參數,目的是可以增加或者減小正樣本在算Cross Entropy時的Loss。實現原理很簡單,在傳統基於sigmoid的交叉熵算法上,正樣本算出的值乘以某個係數接口,算法實現如下。
➤總結
這就是TensorFlow目前提供的有關Cross Entropy的函數實現,用戶需要理解多目標和多分類的場景,根據業務需求(分類目標是否獨立和互斥)來選擇基於sigmoid或者softmax的實現,如果使用sigmoid目前還支持加權的實現,如果使用softmax我們可以自己做onehot coding或者使用更易用的sparse_softmax_cross_entropy_with_logits函數。
TensorFlow提供的Cross Entropy函數基本cover了多目標和多分類的問題,但如果同時是多目標多分類的場景,肯定是無法使用softmax_cross_entropy_with_logits,如果使用sigmoid_cross_entropy_with_logits我們就把多分類的特徵都認為是獨立的特徵,而實際上他們有且只有一個為1的非獨立特徵,計算Loss時不如Softmax有效。這裡可以預測下,未來TensorFlow社區將會實現更多的op解決類似的問題,我們也期待更多人參與TensorFlow貢獻算法和代碼 :)
作者:陳迪豪,就職於小米,負責企業深度學習平臺搭建,參與過HBase、Docker、OpenStack等開源項目,目前專注於TensorFlow和Kubernetes社區。
歡迎技術投稿、約稿、給文章糾錯,請發送郵件至heyc@csdn.net