TensorFlow四種Cross Entropy算法實現和應用

2021-01-14 CSDN

 原文:TensorFlow四種Cross Entropy算法實現和應用,作者授權CSDN轉載。 


交叉熵介紹


交叉熵(Cross Entropy)是Loss函數的一種(也稱為損失函數或代價函數),用於描述模型預測值與真實值的差距大小,常見的Loss函數就是均方平方差(Mean Squared Error),定義如下。



平方差很好理解,預測值與真實值直接相減,為了避免得到負數取絕對值或者平方,再做平均就是均方平方差。注意這裡預測值需要經過sigmoid激活函數,得到取值範圍在0到1之間的預測值。


平方差可以表達預測值與真實值的差異,但在分類問題種效果並不如交叉熵好,原因可以參考這篇博文 。交叉熵的定義如下,截圖來自https://hit-scir.gitbooks.io/neural-networks-and-deep-learning-zh_cn/content/chap3/c3s1.html 。



上面的文章也介紹了交叉熵可以作為Loss函數的原因,首先是交叉熵得到的值一定是正數,其次是預測結果越準確值越小,注意這裡用於計算的「a」也是經過sigmoid激活的,取值範圍在0到1。如果label是1,預測值也是1的話,前面一項y * ln(a)就是1 * ln(1)等於0,後一項(1 - y) * ln(1 - a)也就是0 * ln(0)等於0,Loss函數為0,反之Loss函數為無限大非常符合我們對Loss函數的定義。


這裡多次強調sigmoid激活函數,是因為在多目標或者多分類的問題下有些函數是不可用的,而TensorFlow本身也提供了多種交叉熵算法的實現。


TensorFlow的交叉熵函數


TensorFlow針對分類問題,實現了四個交叉熵函數,分別是


tf.nn.sigmoid_cross_entropy_with_logits、tf.nn.softmax_cross_entropy_with_logits、tf.nn.sparse_softmax_cross_entropy_with_logits和tf.nn.weighted_cross_entropy_with_logits,詳細內容參考API文檔 https://www.tensorflow.org/versions/master/api_docs/python/nn.html#sparse_softmax_cross_entropy_with_logits


sigmoid_cross_entropy_with_logits詳解


我們先看sigmoid_cross_entropy_with_logits,為什麼呢,因為它的實現和前面的交叉熵算法定義是一樣的,也是TensorFlow最早實現的交叉熵算法。這個函數的輸入是logits和targets,logits就是神經網絡模型中的 W * X矩陣,注意不需要經過sigmoid,而targets的shape和logits相同,就是正確的label值,例如這個模型一次要判斷100張圖是否包含10種動物,這兩個輸入的shape都是[100, 10]。注釋中還提到這10個分類之間是獨立的、不要求是互斥,這種問題我們成為多目標,例如判斷圖片中是否包含10種動物,label值可以包含多個1或0個1,還有一種問題是多分類問題,例如我們對年齡特徵分為5段,只允許5個值有且只有1個值為1,這種問題可以直接用這個函數嗎?答案是不可以,我們先來看看sigmoid_cross_entropy_with_logits的代碼實現吧。



可以看到這就是標準的Cross Entropy算法實現,對W * X得到的值進行sigmoid激活,保證取值在0到1之間,然後放在交叉熵的函數中計算Loss。對於二分類問題這樣做沒問題,但對於前面提到的多分類,例如年輕取值範圍在0~4,目標值也在0~4,這裡如果經過sigmoid後預測值就限制在0到1之間,而且公式中的1 - z就會出現負數,仔細想一下0到4之間還不存在線性關係,如果直接把label值帶入計算肯定會有非常大的誤差。因此對於多分類問題是不能直接代入的,那其實我們可以靈活變通,把5個年齡段的預測用onehot encoding變成5維的label,訓練時當做5個不同的目標來訓練即可,但不保證只有一個為1,對於這類問題TensorFlow又提供了基於Softmax的交叉熵函數。


softmax_cross_entropy_with_logits詳解


Softmax本身的算法很簡單,就是把所有值用e的n次方計算出來,求和後算每個值佔的比率,保證總和為1,一般我們可以認為Softmax出來的就是confidence也就是概率,算法實現如下。



softmax_cross_entropy_with_logits和sigmoid_cross_entropy_with_logits很不一樣,輸入是類似的logits和lables的shape一樣,但這裡要求分類的結果是互斥的,保證只有一個欄位有值,例如CIFAR-10中圖片只能分一類而不像前面判斷是否包含多類動物。想一下問什麼會有這樣的限制?在函數頭的注釋中我們看到,這個函數傳入的logits是unscaled的,既不做sigmoid也不做softmax,因為函數實現會在內部更高效得使用softmax,對於任意的輸入經過softmax都會變成和為1的概率預測值,這個值就可以代入變形的Cross Entroy算法- y * ln(a) - (1 - y) * ln(1 - a)算法中,得到有意義的Loss值了。如果是多目標問題,經過softmax就不會得到多個和為1的概率,而且label有多個1也無法計算交叉熵,因此這個函數隻適合單目標的二分類或者多分類問題,TensorFlow函數定義如下。



再補充一點,對於多分類問題,例如我們的年齡分為5類,並且人工編碼為0、1、2、3、4,因為輸出值是5維的特徵,因此我們需要人工做onehot encoding分別編碼為00001、00010、00100、01000、10000,才可以作為這個函數的輸入。理論上我們不做onehot encoding也可以,做成和為1的概率分布也可以,但需要保證是和為1,和不為1的實際含義不明確,TensorFlow的C++代碼實現計劃檢查這些參數,可以提前提醒用戶避免誤用。


sparse_softmax_cross_entropy_with_logits詳解


sparse_softmax_cross_entropy_with_logits是softmax_cross_entropy_with_logits的易用版本,除了輸入參數不同,作用和算法實現都是一樣的。前面提到softmax_cross_entropy_with_logits的輸入必須是類似onehot encoding的多維特徵,但CIFAR-10、ImageNet和大部分分類場景都只有一個分類目標,label值都是從0編碼的整數,每次轉成onehot encoding比較麻煩,有沒有更好的方法呢?答案就是用sparse_softmax_cross_entropy_with_logits,它的第一個參數logits和前面一樣,shape是[batch_size, num_classes],而第二個參數labels以前也必須是[batch_size, num_classes]否則無法做Cross Entropy,這個函數改為限制更強的[batch_size],而值必須是從0開始編碼的int32或int64,而且值範圍是[0, num_class),如果我們從1開始編碼或者步長大於1,會導致某些label值超過這個範圍,代碼會直接報錯退出。這也很好理解,TensorFlow通過這樣的限制才能知道用戶傳入的3、6或者9對應是哪個class,最後可以在內部高效實現類似的onehot encoding,這只是簡化用戶的輸入而已,如果用戶已經做了onehot encoding那可以直接使用不帶「sparse」的softmax_cross_entropy_with_logits函數。


weighted_sigmoid_cross_entropy_with_logits詳解


weighted_sigmoid_cross_entropy_with_logits是sigmoid_cross_entropy_with_logits的拓展版,輸入參數和實現和後者差不多,可以多支持一個pos_weight參數,目的是可以增加或者減小正樣本在算Cross Entropy時的Loss。實現原理很簡單,在傳統基於sigmoid的交叉熵算法上,正樣本算出的值乘以某個係數接口,算法實現如下。



總結


這就是TensorFlow目前提供的有關Cross Entropy的函數實現,用戶需要理解多目標和多分類的場景,根據業務需求(分類目標是否獨立和互斥)來選擇基於sigmoid或者softmax的實現,如果使用sigmoid目前還支持加權的實現,如果使用softmax我們可以自己做onehot coding或者使用更易用的sparse_softmax_cross_entropy_with_logits函數。


TensorFlow提供的Cross Entropy函數基本cover了多目標和多分類的問題,但如果同時是多目標多分類的場景,肯定是無法使用softmax_cross_entropy_with_logits,如果使用sigmoid_cross_entropy_with_logits我們就把多分類的特徵都認為是獨立的特徵,而實際上他們有且只有一個為1的非獨立特徵,計算Loss時不如Softmax有效。這裡可以預測下,未來TensorFlow社區將會實現更多的op解決類似的問題,我們也期待更多人參與TensorFlow貢獻算法和代碼 :)

作者:陳迪豪,就職於小米,負責企業深度學習平臺搭建,參與過HBase、Docker、OpenStack等開源項目,目前專注於TensorFlow和Kubernetes社區。

歡迎技術投稿、約稿、給文章糾錯,請發送郵件至heyc@csdn.net

相關焦點

  • TensorFlow四種Cross Entropy算法的實現和應用
    原文:TensorFlow四種Cross Entropy算法實現和應用,作者授權CSDN轉載。 TensorFlow的交叉熵函數TensorFlow針對分類問題,實現了四個交叉熵函數,分別是詳細內容請參考API文檔 https://www.tensorflow.org/versions/master/api_docs/python/nn.html#sparse_softmax_cross_entropy_with_logits
  • TensorFlow極速入門
    最後給出了在 tensorflow 中建立一個機器學習模型步驟,並用一個手寫數字識別的例子進行演示。1、tensorflow是什麼?tensorflow 是 google 開源的機器學習工具,在2015年11月其實現正式開源,開源協議Apache 2.0。
  • 【強化學習實戰】基於gym和tensorflow的強化學習算法實現
    1新智元推薦【新智元導讀】知乎專欄強化學習大講堂作者郭憲博士開講《強化學習從入門到進階》,我們為您節選了其中的第二節《基於gym和tensorflow的強化學習算法實現》,希望對您有所幫助。同時,由郭憲博士等擔任授課教師的深度強化學習國慶集訓營也將於 10 月 2 日— 6 日在北京舉辦。
  • 手把手教你用 TensorFlow 實現卷積神經網絡(附代碼)
    from tensorflow.examples.tutorials.mnist import input_data  import tensorflow as tf  mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)# 讀取圖片數據集  sess = tf.InteractiveSession
  • 谷歌正式發布TensorFlow 1.5,究竟提升了哪些功能?
    GitHub 地址:https://github.com/tensorflow/tensorflow/releases/tag/v1.5.0原始碼(zip):https://github.com/tensorflow/tensorflow/archive/v1.5.0.zip原始碼(tar.gz):https://github.com/tensorflow/tensorflow
  • Tensorflow 2.0 即將入場
    而就在即將到來的2019年,Tensorflow 2.0將正式入場,給暗流湧動的框架之爭再燃一把火。如果說兩代Tensorflow有什麼根本不同,那應該就是Tensorflow 2.0更注重使用的低門檻,旨在讓每個人都能應用機器學習技術。
  • 程式設計師1小時完成深度學習Resnet,谷歌tensorflow多次圖像大賽冠軍
    好,以上就是簡單的理論入門,接下來我們開始著手用TensorFlow對理論進行代碼實現二、代碼實現(ResNet-34)參數設定(DATA_set.py)NUM_LABELS = 10 #對比標籤數量(模型輸出通道)#卷積參數CONV_SIZE = 3CONV_DEEP = 64#學習優化參數BATCH_SIZE = 100LEARNING_RATE_BASE
  • Tensorflow 2.0的這些新設計,你適應好了嗎?
    而就在即將到來的2019年,Tensorflow 2.0將正式入場,給暗流湧動的框架之爭再燃一把火。如果說兩代Tensorflow有什麼根本不同,那應該就是Tensorflow 2.0更注重使用的低門檻,旨在讓每個人都能應用機器學習技術。
  • 教程| 如何用TensorFlow在安卓設備上實現深度學習推斷
    本文詳細介紹了部署和實現過程。對於個人和公司來說,存在許多狀況是更希望在本地設備上做深度學習推斷的:想像一下當你在旅行途中沒有可靠的網際網路連結時,或是要處理傳輸數據到雲服務的隱私問題和延遲問題時。邊緣計算(Edge computing)是一種在物理上靠近數據生成的位置從而對數據進行處理和分析的方法,為解決這些問題提供了方案。
  • 入門| Tensorflow實戰講解神經網絡搭建詳細過程
    二 、網絡結構的設計  接下來通過Tensorflow代碼,實現MINIST手寫數字識別的過程。首先,如程序1所示,我們導入程序所需要的庫函數、數據集:  程序1:  import tensorflow as tf  from tensorflow.examples.tutorials.mnist import input_data  接下來,我們讀取MNIST數據集,並指定用one_hot的編碼方式;然後定義batch_size、batch_num
  • 機器之心GitHub項目:從零開始用TensorFlow搭建卷積神經網絡
    但是機器之心發過許多詳細解釋的入門文章或教程,因此,我們希望讀者能先了解以下基本概念和理論。當然,本文注重實現,即使對深度學習的基本算法理解不那麼深同樣還是能實現本文所述的內容。conda 環境: activate tensorflow 運行後會變為「(tensorflow) C:Users用戶名>」,然後我們就可以繼續在該 conda 環境內安裝 TensorFlow(本文只使用 CPU 進行訓練,所以可以只安裝 CPU 版):
  • 代碼+實戰:TensorFlow Estimator of Deep CTR——DeepFM/NFM/AFM/...
    經過幾個月的調研,發現目前存在的一些問題:開源的實現基本都是學術界的人在搞,距離工業應用還有較大的鴻溝模型實現大量調用底層 API,各版本實現千差萬別,代碼臃腫難懂,遷移成本較高單機,放到工業場景下跑不動針對存在的問題做了一些探索,摸索出一套可行方案,有以下特性:讀數據採用 Dataset API,支持 parallel and prefetch 讀取通過
  • 分享TensorFlow Lite應用案例
    TF Lite 對於 CNN 類的應用支持較好,目前對於 RNN 的支持尚存在 op 支持不足的缺點。但是考慮到內存消耗和性能方面的提升,Kika 仍然建議投入一部分的研發力量,在移動端考慮採用 TF Lite 做為基於 RNN 深度學習模型的 inference 部署方案。   2.
  • TensorFlow 攜手 NVIDIA,使用 TensorRT 優化 TensorFlow Serving...
    HTTP/REST API at:localhost:8501 …$ curl -o /tmp/resnet/resnet_client.py https://raw.githubusercontent.com/tensorflow/serving/master/tensorflow_serving/example/resnet_client.py
  • TensorFlow 資源大全中文版
    循環神經網絡模型/工程圖片形態轉換 – 無監督圖片形態轉換的實現Show, Attend and Tell算法 -基於聚焦機制的自動圖像生成器Neural Style – Neural Style 算法的TensorFlow實現Pretty Tensor – Pretty Tensor提供了高級別的
  • 使用Tensorflow實現RNN-LSTM的菜鳥指南
    Tensorflow和其他各種庫(Theano,Torch,PyBrain)為用戶提供了設計模型的工具,而沒有深入了解實現神經網絡,優化或反向傳播算法的細節。Danijar概述了組織Tensorflow模型的好方法,您可能希望稍後使用它來整理代碼。出於本教程的目的,我們將跳過這一點,並專注於編寫正常工作的代碼。首先導入所需的包。
  • tensorflow初級必學算子
    在之前的文章中介紹過,tensorflow框架的核心是將各式各樣的神經網絡抽象為一個有向無環圖,圖是由tensor以及tensor變換構成;雖然現在有很多高階API可以讓開發者忽略這層抽象,但對於靈活度要求比較高的算法仍然需要開發者自定義網絡圖,所以建議開發者儘量先學習tf1.x
  • 步履不停:TensorFlow 2.4新功能一覽!
    集合運算是 TensorFlow 圖表中的單個算子,可以根據硬體、網絡拓撲和張量大小在 TensorFlow 運行時中自動選擇 All Reduce 算法。集合運算還可實現其他集合運算,例如廣播和 All Gather。
  • 從系統和代碼實現角度解析TensorFlow的內部實現原理|深度
    本文依據對Tensorflow(簡稱TF)白皮書[1]、TF Github[2]和TF官方教程[3]的理解,從系統和代碼實現角度講解TF的內部實現原理。以Tensorflow r0.8.0為基礎,本文由淺入深的闡述Tensor和Flow的概念。先介紹了TensorFlow的核心概念和基本概述,然後剖析了OpKernels模塊、Graph模塊、Session模塊。1.
  • 深度學習筆記8:利用Tensorflow搭建神經網絡
    個人公眾號:數據科學家養成記 (微信ID:louwill12)前文傳送門:深度學習筆記1:利用numpy從零搭建一個神經網絡深度學習筆記2:手寫一個單隱層的神經網絡深度學習筆記3:手動搭建深度神經網絡(DNN)深度學習筆記4:深度神經網絡的正則化深度學習筆記5:正則化與