實踐教程 | 解決pytorch半精度amp訓練nan問題

2021-12-28 極市平臺
Why?

如果要解決問題,首先就要明確原因:為什麼全精度訓練時不會nan,但是半精度就開始nan?這其實分了三種情況:

1&2我想放到後面討論,因為其實大部分報nan都是第三種情況。這裡來先看看3。什麼情況下會出現情況3?這個討論給出了不錯的解釋:

【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

給大家翻譯翻譯:在使用ce loss 或者 bceloss的時候,會有log的操作,在半精度情況下,一些非常小的數值會被直接捨入到0,log(0)等於啥?——等於nan啊!

於是邏輯就理通了:回傳的梯度因為log而變為nan->網絡參數nan-> 每輪輸出都變成nan。(;´Д`)

How?

問題定義清楚,那解決方案就非常簡單了,只需要在涉及到log計算時,把輸入從half精度轉回float32:

x = x.float()
x_sigmoid = torch.sigmoid(x)

一些思考&廢話

這裡我接著討論下我第一次看到nan之後,企圖直接copy別人的解決方案,但解決不掉時踩過的坑。比如:

有些blog會建議你從默認的1e-8 改為 1e-3,比如這篇:【pytorch1.1 半精度訓練 Adam RMSprop 優化器 Nan 問題】https://link.zhihu.com/?target=https%3A//blog.csdn.net/gwb281386172/article/details/104705195

經過上面的分析,我們就能知道為什麼這種方法不行——這個方案是針對優化器的數值穩定性做的修改,而loss計算這一步在優化器之前,如果loss直接nan,優化器的eps是救不回來的(託腮)。

那麼這個方案在哪些場景下有效?——在loss輸出不是nan時(感覺說了一句廢話)。optimizer的eps是保證在進行除法backwards時,分母不出現0時需要加上的微小量。在半精度情況下,分母加上1e-8就仿佛聽君一席話,因此,需要把eps調大一點。

GradScaler是autocast的好夥伴,在官方教程上就和autocast配套使用:

from torch.cuda.amp import autocast, GradScaler
...
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()

with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()

scaler.step(optimizer)
scaler.update()

具體原理不是我這篇文章討論的範圍,網上很多教程都說得很清楚了,比如這個就不錯:

【Gemfield:PyTorch的自動混合精度(AMP)】https://zhuanlan.zhihu.com/p/165152789

但是我這裡想討論另一點:scaler.step(optimizer)的運行原理。

在初始化GradScaler的時候,有一個參數enabled,值默認為True。如果為True,那麼在調用scaler方法時會做梯度縮放來調整loss,以防半精度狀況下,梯度值過大或者過小從而被nan或者inf。而且,它還會判斷本輪loss是否是nan,如果是,那麼本輪計算的梯度不會回傳,同時,當前的scale係數乘上backoff_factor,縮減scale的大小_。_

那麼,為什麼這一步已經判斷了loss是不是nan,還是會出現網絡損失持續nan的情況呢?

這時我們就得再往前思考一步了:為什麼loss會變成nan?回到文章一開始說的:

(1)計算loss 時,出現了除以0的情況;

(2)loss過大,被半精度判斷為inf;

(3)網絡直接輸出了nan。

(1)&(2),其實是可以通過scaler.step(optimizer)解決的,分別由optimizer和scaler幫我們捕捉到了nan的異常。但(3)不行,(3)意味著部分甚至全部的網絡參數已經變成nan了。這可能是在更之前的梯度回傳過程中除以0導致的——首先【回傳的梯度不是nan】,所以scaler不會捕捉異常;其次,由於使用了半精度,optimizer接收到了【已經因為精度損失而變為nan的loss】,nan不管加上多大的eps,都還是nan,所以optimizer也無法處理異常,最終導致網絡參數nan。

所以3,只能通過本文一開始提出的方案來解決。其實,大部分分類問題在使用半精度時出現nan的情況都是第3種情況,也只能通過把精度轉回為float32,或者在計算log時加上微小量來避免(但這樣會損失精度)。

參考

【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

如果覺得有用,就請分享到朋友圈吧!

相關焦點

  • 肝了一晚上,總結了Pytorch的訓練秘訣!
    然而,使用大 batch 的不足是,這可能導致解決方案的泛化能力比使用小 batch 的差。PyTorch 1.6 版本包括對 PyTorch 的自動混合精度訓練的本地實現。這裡想說的是,與單精度 (FP32) 相比,某些運算在半精度 (FP16) 下運行更快,而不會損失準確率。AMP 會自動決定應該以哪種精度執行哪種運算。這樣既可以加快訓練速度,又可以減少內存佔用。
  • PyTorch 源碼解讀之 torch.cuda.amp: 自動混合精度詳解
    這樣在不改變模型、不降低模型訓練精度的前提下,可以縮短訓練時間,降低存儲需求,因而能支持更多的 batch size、更大模型和尺寸更大的輸入進行訓練。PyTorch 從 1.6 以後(在此之前 OpenMMLab 已經支持混合精度訓練,即 Fp16OptimizerHook),開始原生支持 amp,即torch.cuda.amp module。
  • 提升PyTorch訓練速度,小哥哥總結了17種方法!
    然而,使用大 batch 的不足是,這可能導致解決方案的泛化能力比使用小 batch 的差。PyTorch 1.6 版本包括對 PyTorch 的自動混合精度訓練的本地實現。這裡想說的是,與單精度 (FP32) 相比,某些運算在半精度 (FP16) 下運行更快,而不會損失準確率。AMP 會自動決定應該以哪種精度執行哪種運算。這樣既可以加快訓練速度,又可以減少內存佔用。
  • 使用AMP和Tensor Cores得到更快速,更節省內存的PyTorch模型
    Tensor cores支持混合精度數學,即以半精度(FP16)進行輸入,以全精度(FP32)進行輸出。上述類型的操作對許多深度學習任務具有內在價值,而Tensor cores為這種操作提供了專門的硬體。現在,使用FP16和FP32主要有兩個好處。FP16需要更少的內存,因此更容易訓練和部署大型神經網絡。
  • PyTorch常見的12坑
    7. pytorch的可重複性問題參考這篇博文:https://blog.csdn.net/hyk_1996/article/details/84307108 8. 多GPU的處理機制使用多GPU時,應該記住pytorch的處理邏輯是:1.在各個GPU上初始化模型。2.前向傳播時,把batch分配到各個GPU上進行計算。
  • 9個讓PyTorch模型訓練提速的技巧
    Lightning是在Pytorch之上的一個封裝,它可以自動訓練,同時讓研究人員完全控制關鍵的模型組件。Lightning 使用最新的最佳實踐,並將你可能出錯的地方最小化。我們為MNIST定義LightningModel並使用Trainer來訓練模型。
  • 【Pytorch】Pytorch多機多卡分布式訓練
    關於Pytorch分布訓練的話,大家一開始接觸的往往是DataParallel,這個wrapper能夠很方便的使用多張卡,而且將進程控制在一個。唯一的問題就在於,DataParallel只能滿足一臺機器上gpu的通信,而一臺機器一般只能裝8張卡,對於一些大任務,8張卡就很吃力了,這個時候我們就需要面對多機多卡分布式訓練這個問題了,噩夢開始了。
  • PyTorch 中文教程最新版
    本文檔的定位是 PyTorch 入門教程,主要針對想要學習 PyTorch 的學生群體或者深度學習愛好者。通過教程的學習,能夠實現零基礎想要了解和學習深度學習,降低自學的難度,快速學習 PyTorch。
  • 【Pytorch】新手如何入門pytorch?
    另外jcjohnson 的Simple examples to introduce PyTorch 也不錯第二步 example 參考 pytorch/examples 實現一個最簡單的例子(比如訓練mnist )。
  • 當代研究生應當掌握的5種Pytorch並行訓練方法(單機多卡)
    /blob/master/dataparallel.py2、使用 torch.distributed 加速並行訓練https://github.com/tczhangzhi/pytorch-distributed/blob/master/distributed.py3、使用 torch.multiprocessing
  • 庫、教程、論文實現,這是一份超全的PyTorch資源列表(Github 2.2K星)
    這些項目有很多是官方的實現,其中 FAIR 居多,一般會有系統的使用說明,包含安裝、加載、訓練、測試、演示等多方面的詳細解釋。例如哈佛大學的 OpenNMT 項目,它是非常流行的神經機器翻譯工具包。從導入自定義數據集、加載詞嵌入向量到完成神經機器翻譯模型的訓練,OpenNMT 能支持整個流程,並且官方也一直在更新。
  • 60分鐘入門PyTorch,官方教程手把手教你訓練第一個深度學習模型(附連結)
    PyTorch 的一份官方教程表示:只需要 60 分鐘。教程連結:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html這是一份非常簡潔的學習材料,目標是讓學習者了解 PyTorch 的 Tensor 庫和神經網絡,以及如何訓練一個可以進行圖像分類的神經網絡。
  • PyTorch1.6:新增自動混合精度訓練、Windows版開發維護權移交微軟
    新版本增加了一個 amp 子模塊,支持本地自動混合精度訓練。Facebook 還表示,微軟已擴大了對 PyTorch 社區的參與,現在擁有 PyTorch 在 Windows 上的開發和維護所有權。相比於以往的 PyTorch 版本,本次即將發布的 PyTorch 1.6 有哪些吸引人的地方呢?
  • 帶你少走彎路:強烈推薦的Pytorch快速入門資料和翻譯(可下載)
    備註:TensorFlow的快速入門資料很負責任地說:看完這些資料,Pytorch基本入門了,接下來碰到問題能自己查資料解決了!本文內容較多,可以在線學習,如果需要本地調試,請到github下載:https://github.com/fengdu78/Data-Science-Notes/tree/master/8.deep-learning/PyTorch_beginner此教程為翻譯官方地址:https://pytorch.org/tutorials/beginner
  • PyTorch 深度學習官方入門中文教程 pdf 下載|PyTorchChina
    官方教程包含了 PyTorch 介紹,安裝教程;60分鐘快速入門教程,可以迅速從小白階段完成一個分類器模型;計算機視覺常用模型,方便基於自己的數據進行調整,不再需要從頭開始寫;自然語言處理模型,聊天機器人,文本生成等生動有趣的項目。總而言之:如果你想了解一下 PyTorch,可以看介紹部分。
  • 【Github 3.5K 星】PyTorch資源列表:450個NLP/CV/SP、論文實現、庫、教程&示例
    這些項目有很多是官方的實現,其中 FAIR 居多,一般會有系統的使用說明,包含安裝、加載、訓練、測試、演示等多方面的詳細解釋。例如哈佛大學的 OpenNMT 項目,它是非常流行的神經機器翻譯工具包。從導入自定義數據集、加載詞嵌入向量到完成神經機器翻譯模型的訓練,OpenNMT 能支持整個流程,並且官方也一直在更新。
  • Pytorch 中文文檔和中文教程
    筆者獲得了ApacheCN社區的同意,放出該社區的翻譯文檔和官方教程,歡迎大家多去GitHub頁面 fork,star!簡單介紹GitHub項目管理:https://github.com/apachecn/pytorch-doc-zh
  • 《PyTorch中文手冊》來了
    這是一本開源的書籍,目標是幫助那些希望和使用 PyTorch 進行深度學習開發和研究的朋友快速入門,其中包含的 Pytorch 教程全部通過測試保證可以成功運行。由於本人水平有限,在寫此教程的時候參考了一些網上的資料,在這裡對他們表示敬意,我會在每個引用中附上原文地址,方便大家參考。
  • PyTorch 官方教程中文版正式上線,激動人心的大好事!
    教程地址:http://pytorch123.com/本文檔主要使用於 PyTorch 入門學者,主要參考 PyTorch 官方文檔。如果你想解決計算機視覺問題,可以看 CV 部分。如果你想解決自然語言處理問題,可以看 NLP 部分。整個教程共包含了 7 部分,內容由簡單到複雜,適合不同層次的學習要求。下面分別進行介紹。1.
  • 【乾貨】史上最全的PyTorch學習資源匯總
    · 比較偏算法實戰的PyTorch代碼教程(https://github.com/yunjey/pytorch-tutorial):在github上有很高的star。建議大家在閱讀本文檔之前,先學習上述兩個PyTorch基礎教程。