這篇文章整理自我的知乎回答(id: Hanson),分別對深度學習中的多個loss如何平衡 以及 有哪些「魔改」損失函數,曾經拯救了你的深度學習模型 這兩個問題進行了解答。
1. 深度學習的多個loss如何平衡?1.1 mtcnn對於多任務學習而言,它每一組loss之間的數量級和學習難度並不一樣,尋找平衡點是個很難的事情。我舉兩個我在實際應用中碰到的問題。第一個是多任務學習算法MTCNN,這算是人臉檢測領域最經典的算法之一,被各家廠商魔改,其性能也是很不錯的,也有很多版本的開源實現(如果不了解的話,請看:https://blog.csdn.net/qq_36782182/article/details/83624357)。但是我在測試各種實現的過程中,發現竟然沒有一套實現是超越了原版的(https://github.com/kpzhang93/MTCNN_face_detection_alignment)。下圖中是不同版本的實現,打了碼的是我復現的結果。
不同版本mtcnn在FDDB上roc曲線這是一件很困擾的事情,參數、網絡結構大家設置都大差不差。但效果確實是迥異。
clsloss表示置信度score的loss,boxloss表示預測框位置box的loss,landmarksloss表示關鍵點位置landmarks的loss。
那麼
其實有個比較不錯的主意,就是只保留必要的那兩組權值,把另外一組設置為0,比如
就比如這個MTCNN中的ONet,它回歸了包括score、bbox、landmarks,我在用pytorch復現的時候,出現一些有意思的情況,就是將landmarks這條任務凍結後(即
但是加上landmarks任務後(
上面這個實驗意在說明,要存在就好的loss權重組合,那麼你的網絡結構就必須設計的足夠好。不然你可能還需要通過上述的實驗就驗證你的網絡結構。從多種策略的設計上去解決這種loss不均衡造成的困擾。
和@葉不知(知乎用戶)討論後,有一篇論文也可以提供參考:
https://arxiv.org/abs/1810.04002
1.2 ocr-table-ssd第二個是我之前做過一點OCR方面的工作,這個是我對於表格框格式化方面做的工作,基本算原創工作。
https://github.com/hanson-young/ocr-table-ssd改進版本的SSD表格檢測
算法是基於SSD改的,與原有SSD相比增加了一個預測heatmap的分支,算是一種attention機制的表現吧。改進後訓練達到相同的精度和loss,SSD用時10小時,改進後的方法耗時僅需10-20min。在訓練過程中如果兩個分支一起訓練,很難發揮網絡的真正意義,並且收斂到不是很理想的地方,所以訓練過程也挺重要的,在實驗中,將原來的optimizer從SGD(不易收斂,可能和學習率有關)換到RMSProp:
先凍結SSD網絡,然後訓練segmentation分支,到收斂再凍結segmentation分支進行SSD部分的訓練,到收斂原圖預測結果heatmap因為表格尺度的影響,不加入heatmap分支會導致圖像被過分拉升,導致無法檢測到表格框。
加入heatmap後還有個好處就是為表格的對齊提供了可能。
原圖
如果直接檢測,對於一個矩形框來說,恐怕是會非常吃力的。如果
heatmap -> 閾值分割 -> Sobel -> HoughLineP -> angle求出表格的傾斜角angle後,可以將原圖和heatmap旋轉統一的angle後concatenation,這樣再接著跑SSD,對齊後的效果比較明顯,解決了傾斜角度過大,帶來bbox框過大的影響,詳細見下圖。
可以求出角度
然後進行對齊工作
對齊後的結果
是不是能好很多。
2. 有哪些「魔改」損失函數,曾經拯救了你的深度學習模型?我在做缺陷檢測時候對比了一些loss的性能,其實確實是那句話,適合自己的才是最好的。以下我用實際例子來說明這個問題。
2.1 實驗條件為了實驗方便,我們使用了CrackForest數據集(https://github.com/cuilimeng/CrackForest-dataset)做訓練測試,目的是去將裂紋缺陷分割出來,總共118張圖片,其中訓練樣本94張,測試樣本24張,採用旋轉、隨機縮放、裁剪、圖像亮度增強、隨機翻轉增強操作,保證實驗參數一直,模型均是類Unet網絡,僅僅使用了depthwise卷積結構,進行了如下幾組實驗,並在tensorboard中進行圖像預測狀態的觀測。
CrackForest數據集samples2.2 weighted CrossEntropy
在loss函數的選取時,類似focal loss,常規可以嘗試使用cross_entropy_loss_RCF(https://github.com/meteorshowers/RCF-pytorch/blob/master/functions.py),或者是weighted MSE,因為圖像大部分像素為非缺陷區域,只有少部分像素為缺陷裂痕,這樣可以方便解決樣本分布不均勻的問題
validation
epoch[625] | val_loss: 2708.3965 | precisions: 0.2113 | recalls: 0.9663 | f1_scores: 0.3467
training
2018-11-27 11:53:56 [625-0] | train_loss: 2128.9360 | precisions: 0.2554 | recalls: 0.9223 | f1_scores: 0.4000
2018-11-27 11:54:13 [631-2] | train_loss: 1416.9917 | precisions: 0.2359 | recalls: 0.9541 | f1_scores: 0.3782
2018-11-27 11:54:31 [637-4] | train_loss: 1379.9745 | precisions: 0.1916 | recalls: 0.9591 | f1_scores: 0.3194
2018-11-27 11:54:50 [643-6] | train_loss: 1634.6824 | precisions: 0.3067 | recalls: 0.9636 | f1_scores: 0.4654
2018-11-27 11:55:10 [650-0] | train_loss: 2291.4810 | precisions: 0.2498 | recalls: 0.9317 | f1_scores: 0.3940weighted CrossEntropy loss的最佳預測結果
weighted CrossEntropy 在實驗過程中因為圖片中的缺陷部分太過稀疏,導致了weights的選取有很大的問題存在,訓練後會發現其recall極高,但是precision卻也是很低,loss曲線也極其不規律,基本是沒法參考的,能把很多疑似缺陷的地方給弄進來.因此只能將weights改為固定常量,這樣可以在一定程度上控制均衡recall和precision,但調參也會相應變得麻煩
2.3 MSE(不帶權重)我們先來試試MSE,在分割上最常規的loss
validation
epoch[687] | val_loss: 0.0063 | precisions: 0.6902 | recalls: 0.6552 | f1_scores: 0.6723 | time: 0
epoch[875] | val_loss: 0.0067 | precisions: 0.6324 | recalls: 0.7152 | f1_scores: 0.6713 | time: 0
epoch[1250] | val_loss: 0.0066 | precisions: 0.6435 | recalls: 0.7230 | f1_scores: 0.6809 | time: 0
epoch[1062] | val_loss: 0.0062 | precisions: 0.6749 | recalls: 0.6835 | f1_scores: 0.6792 | time: 0
training
2018-11-27 15:01:34 [1375-0] | train_loss: 0.0055 | precisions: 0.6867 | recalls: 0.6404 | f1_scores: 0.6627
2018-11-27 15:01:46 [1381-2] | train_loss: 0.0045 | precisions: 0.7223 | recalls: 0.6747 | f1_scores: 0.6977
2018-11-27 15:01:58 [1387-4] | train_loss: 0.0050 | precisions: 0.7336 | recalls: 0.7185 | f1_scores: 0.7259
2018-11-27 15:02:09 [1393-6] | train_loss: 0.0058 | precisions: 0.6719 | recalls: 0.6196 | f1_scores: 0.6447
2018-11-27 15:02:21 [1400-0] | train_loss: 0.0049 | precisions: 0.7546 | recalls: 0.7191 | f1_scores: 0.7364
2018-11-27 15:02:32 [1406-2] | train_loss: 0.0057 | precisions: 0.7286 | recalls: 0.6919 | f1_scores: 0.7098
2018-11-27 15:02:42 [1412-4] | train_loss: 0.0054 | precisions: 0.7850 | recalls: 0.6932 | f1_scores: 0.7363
2018-11-27 15:02:53 [1418-6] | train_loss: 0.0050 | precisions: 0.7401 | recalls: 0.7223 | f1_scores: 0.7311MSE loss的最佳預測結果
MSE在訓練上較cross entropy就比較穩定,在heatmap預測上優勢挺明顯
2.4 weighted MSE(8:1)既然MSE的效果還不錯,那麼是否加權後就更好了呢,其實從我做的實驗效果來看,並不準確,沒想像的那麼好,甚至導致性能下降了
validation
epoch[2000] | val_loss: 11002.3584 | precisions: 0.5730 | recalls: 0.7602 | f1_scores: 0.6535 | time: 1
training
2018-11-27 13:12:44 [2000-0] | train_loss: 7328.5186 | precisions: 0.6203 | recalls: 0.6857 | f1_scores: 0.6514
2018-11-27 13:13:01 [2006-2] | train_loss: 6290.4971 | precisions: 0.5446 | recalls: 0.5346 | f1_scores: 0.5396
2018-11-27 13:13:18 [2012-4] | train_loss: 5887.3525 | precisions: 0.6795 | recalls: 0.6064 | f1_scores: 0.6409
2018-11-27 13:13:36 [2018-6] | train_loss: 6102.1934 | precisions: 0.6613 | recalls: 0.6107 | f1_scores: 0.6350
2018-11-27 13:13:53 [2025-0] | train_loss: 7460.8853 | precisions: 0.6225 | recalls: 0.7137 | f1_scores: 0.6650weighted MSE loss的最佳預測結果
以上loss在性能表現上,MSE > weighted MSE > weighted CrossEntropy,最簡單的卻在該任務上取得了最好的效果,所以我們接下來該做的,就是去懷疑人生了!
歡迎掃碼關注: