PyTorch 學習筆記(五):Finetune和各層定製學習率

2021-03-02 極市平臺

加入極市專業CV交流群,與6000+來自騰訊,華為,百度,北大,清華,中科院等名企名校視覺開發者互動交流!更有機會與李開復老師等大牛群內互動!

同時提供每月大咖直播分享、真實項目需求對接、乾貨資訊匯總,行業技術交流點擊文末「閱讀原文」立刻申請入群~

作者 | 餘霆嵩

來源專欄 | PyTorch學習筆記

本文截取自一個github上千星的火爆教程——《PyTorch 模型訓練實用教程》,教程內容主要為在 PyTorch 中訓練一個模型所可能涉及到的方法及函數的詳解等,本文為作者整理的學習筆記(五),後續會繼續更新這個系列,歡迎關注。

項目代碼:https://github.com/tensor-yu/PyTorch_Tutorial

系列回顧:

我們知道一個良好的權值初始化,可以使收斂速度加快,甚至可以獲得更好的精度。而在實際應用中,我們通常採用一個已經訓練模型的模型的權值參數作為我們模型的初始化參數,也稱之為Finetune,更寬泛的稱之為遷移學習。遷移學習中的Finetune技術,本質上就是讓我們新構建的模型,擁有一個較好的權值初始值。

finetune權值初始化三步曲,finetune就相當於給模型進行初始化,其流程共用三步:

第一步:保存模型,擁有一個預訓練模型; 第二步:加載模型,把預訓練模型中的權值取出來; 第三步:初始化,將權值對應的「放」到新模型中

一、Finetune之權值初始化

在進行finetune之前我們需要擁有一個模型或者是模型參數,因此需要了解如何保存模型。官方文檔中介紹了兩種保存模型的方法,一種是保存整個模型,另外一種是僅保存模型參數(官方推薦用這種方法),這裡採用官方推薦的方法。


第一步:保存模型參數
若擁有模型參數,可跳過這一步。假設創建了一個net = Net(),並且經過訓練,通過以下方式保存:torch.save(net.state_dict(), 'net_params.pkl')


第二步:加載模型
進行三步曲中的第二步,加載模型,這裡只是加載模型的參數:pretrained_dict = torch.load('net_params.pkl')


第三步:初始化
進行三步曲中的第三步,將取到的權值,對應的放到新模型中:首先我們創建新模型,並且獲取新模型的參數字典net_state_dict:net = Net() net_state_dict = net.state_dict() 
接著將pretrained_dict裡不屬於net_state_dict的鍵剔除掉:pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
然後,用預訓練模型的參數字典 對 新模型的參數字典net_state_dict 進行更新:net_state_dict.update(pretrained_dict_1)
最後,將更新了參數的字典 「放」回到網絡中:net.load_state_dict(net_state_dict)

這樣,利用預訓練模型參數對新模型的權值進行初始化過程就做完了。

採用finetune的訓練過程中,有時候希望前面層的學習率低一些,改變不要太大,而後面的全連接層的學習率相對大一些。這時就需要對不同的層設置不同的學習率,下面就介紹如何為不同層配置不同的學習率。

二、不同層設置不同的學習率

在利用pre-trained model的參數做初始化之後,我們可能想讓fc層更新相對快一些,而希望前面的權值更新小一些,這就可以通過為不同的層設置不同的學習率來達到此目的。

為不同層設置不同的學習率,主要通過優化器對多個參數組進行設置不同的參數。所以,只需要將原始的參數組,劃分成兩個,甚至更多的參數組,然後分別進行設置學習率。 這裡將原始參數「切分」成fc3層參數和其餘參數,為fc3層設置更大的學習率。

請看代碼:

ignored_params = list(map(id, net.fc3.parameters())) base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) optimizer = optim.SGD([{'params': base_params},{'params': net.fc3.parameters(), 'lr': 0.001*10}], 0.001, momentum=0.9, weight_decay=1e-4)

第一行+ 第二行的意思就是,將fc3層的參數net.fc3.parameters()從原始參數net.parameters()中剝離出來 base_params就是剝離了fc3層的參數的其餘參數,然後在優化器中為fc3層的參數單獨設定學習率。

optimizer = optim.SGD(.)這裡的意思就是 base_params中的層,用 0.001, momentum=0.9, weight_decay=1e-4 fc3層設定學習率為: 0.001*10

完整代碼位於 :

https://github.com/tensor-yu/PyTorch_Tutorial/blob/master/Code/2_model/2_finetune.py


補充:

挑選出特定的層的機制是利用內存地址作為過濾條件,將需要單獨設定的那部分參數,從總的參數中剔除。 base_params 是一個list,每個元素是一個Parameter 類 net.fc3.parameters() 是一個

ignored_params = list(map(id, net.fc3.parameters())) net.fc3.parameters() 是一個 所以迭代的返回其中的parameter,這裡有weight 和 bias 最終返回weight和bias所在內存的地址

*延伸閱讀

點擊左下角閱讀原文」,即可申請加入極市目標跟蹤、目標檢測、工業檢測、人臉方向、視覺競賽等技術交流群,更有每月大咖直播分享、真實項目需求對接、乾貨資訊匯總,行業技術交流,一起來讓思想之光照的更遠吧~

覺得有用麻煩給個在看啦~  

相關焦點

  • [PyTorch 學習筆記] 7.2 模型 Finetune
    本章代碼:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/finetune_resnet18.py這篇文章主要介紹了模型的 Finetune。
  • 從零開始深度學習Pytorch筆記(12)—— nn.Module
    筆記(2)——張量的創建(上)從零開始深度學習Pytorch筆記(3)——張量的創建(下)從零開始深度學習Pytorch筆記(4)——張量的拼接與切分從零開始深度學習Pytorch筆記(5)——張量的索引與變換
  • 【乾貨】史上最全的PyTorch學習資源匯總
    · 開源書籍(https://github.com/zergtant/pytorch-handbook):這是一本開源的書籍,目標是幫助那些希望和使用PyTorch進行深度學習開發和研究的朋友快速入門。但本文檔不是內容不是很全,還在持續更新中。
  • (二)pytorch學習筆記
    (一)pytorch學習筆記(二)pytorch學習筆記關係擬合 (回歸)我會這次會來見證神經網絡是如何通過簡單的形式將一群數據用一條線條來表示. 或者說, 是如何在數據當中找到他們的關係, 然後用神經網絡模型來建立一個可以代表他們關係的線條.建立數據集我們創建一些假數據來模擬真實的情況.
  • 新手必備 | 史上最全的PyTorch學習資源匯總
    這是一本開源的書籍,目標是幫助那些希望和使用PyTorch進行深度學習開發和研究的朋友快速入門。但本文檔不是內容不是很全,還在持續更新中。(5)最後,為大家推薦一個簡單易上手的PyTorch中文文檔,非常適合新手學習:https://github.com/fendouai/pytorch1.0-cn。
  • pytorch學習筆記(2):在 MNIST 上實現一個 cnn
    在前面我要先說一下,這個系列是為了讓大家對 pytorch 從入門到熟悉,對於 deep learning 相關的知識我們不會花費過多的內容去介紹。如果大家對一些 DL 的基礎相關知識不懂的話,推薦幾個資源去學習:所以我們在筆記中對於一些相關的知識就不做深入介紹了。
  • 【乾貨】基於pytorch的CNN、LSTM神經網絡模型調參小結
    Demo Site:  https://github.com/bamtercelboo/cnn-lstm-bilstm-deepcnn-clstm-in-pytorchPytorch是一個較新的深度學習框架,是一個 Python 優先的深度學習框架,能夠在強大的 GPU 加速基礎上實現張量和動態神經網絡。
  • PyTorch 學習筆記(七):PyTorch的十個優化器
    PyTorch 中訓練一個模型所可能涉及到的方法及函數的詳解等,本文為作者整理的學習筆記(六),後續會繼續更新這個系列,歡迎關注。lr(float)- 初始學習率,可按需隨著訓練過程不斷調整學習率。這個學習率的變化,會受到梯度的大小和迭代次數的影響。梯度越大,學習率越小;梯度越小,學習率越大。缺點是訓練後期,學習率過小,因為Adagrad累加之前所有的梯度平方作為分母。
  • 資料|【乾貨】PyTorch學習資源匯總
    建議大家在閱讀本文檔之前,先學習上述兩個PyTorch基礎教程。開源書籍:這是一本開源的書籍,目標是幫助那些希望和使用PyTorch進行深度學習開發和研究的朋友快速入門。但本文檔不是內容不是很全,還在持續更新中。簡單易上手的PyTorch中文文檔:非常適合新手學習。
  • 【深度學習】textCNN論文與原理
    如果CNN不是很了解的話,可以看看我之前的文章:【深度學習】卷積神經網絡-CNN簡單理論介紹[1] 、 【深度學習】卷積神經網絡-圖片分類案例(pytorch實現)[2],當然既然是一種深度學習方法進行文本分類,跑不了使用詞向量相關內容,所以讀者也是需要有一定詞向量(也就是詞語的一種分布式表示而已)的概念。
  • pytorch專題前言 | 為什麼要學習pytorch?
    2.為什麼要學習pytorch呢?3.學習了pytorch我怎麼應用呢?4.按照什麼順序去學習pytorch呢?5.網上那麼多資料如何選擇呢?現在開始逐一的對以上問題提出自己的看法,可能想的不夠周全,歡迎討論區一起探討!1.生物學科的朋友需要學編程麼?需要!
  • 「fine-tune」別理解成「好的調子」!
    大家好,今天我們分享一個非常有用且地道的表達——fine-tune, 這個短語的含義不是指「好的調子」,其正確的含義是:fine-tune 對…進行微調,調整She spent hours fine-tuning
  • 乾貨| BERT fine-tune 終極實踐教程
    因此對於不同數據集的適配,只需要修改代碼中的processor部分,就能進行代碼的訓練、交叉驗證和測試。以下是奇點機智技術團隊對BERT在中文數據集上的fine tune終極實踐教程。在自己的數據集上運行 BERTBERT的代碼同論文裡描述的一致,主要分為兩個部分。
  • 雲計算學習:用PyTorch實現一個簡單的分類器
    回想了一下自己關於 pytorch 的學習路線,一開始找的各種資料,寫下來都能跑,但是卻沒有給自己體會到學習的過程。有的教程一上來就是寫一個 cnn,雖然其實內容很簡單,但是直接上手容易讓人找不到重點,學的雲裡霧裡。
  • 小樣本學習跨域(Cross-domain)問題總結
    呈現較大幅度下降;在miniimagenet和CUB-miniimagenet兩組實驗中,pre-training + fine-tuning的結果優於各元學習模型,表明在小樣本學習的跨域問題上,元學習方法似乎失去了優勢;二:A Broader Study of Cross-Domain Few-Shot Learning文章連結
  • 深度學習不得不會的遷移學習(Transfer Learning)
    2.3 遷移學習有幾種方式     2.4 三種遷移學習方式的對比三、實驗:嘗試對模型進行微調,以進一步提升模型性能    3.1 Fine-tune所扮演的角色     3.2 Fine-tune 也可以有三種操作方式     3.3 不同數據集下使用微調     3.4
  • 使用resnet, inception3進行fine-tune出現訓練集準確率很高但驗證集很低的問題
    在這篇文章中,我會構建一個案例來說明為什麼Keras的BN層對遷移學習並不友好,並給出對Keras BN層的一個修復補丁,以及修復後的實驗效果。1. Introduction這一節我會簡要介紹遷移學習和BN層,以及learning_phase的工作原理,Keras BN層在各個版本中的變化。
  • Github 2.2K星的超全PyTorch資源列表
    在本文中,我們對各部分資源進行了介紹,感興趣的同學可收藏、查用。其中第 4 個項目可以用於將你的定製圖像分類模型和當前最佳模型進行對比,快速知道你的項目到底有沒有希望,作者戲稱該項目為「Project Killer」。1.pytorch vision:計算機視覺領域的數據集、轉換和模型。
  • 李宏毅老師深度學習與人類語言處理課程視頻及課件(附下載)
    李宏毅老師2020新課 深度學習與人類語言處理課程 昨天(7月10日)終於完結了,這門課程裡語音和文本的內容各佔一半,主要關注近3
  • PyTorch  深度學習新手入門指南
    ,這篇文章是為想開始用pytorch來進行深度學習項目研究的人準備的。如果在領英上,你也許會說自己是一個深度學習的狂熱愛好者,但是你只會用 keras 搭建模型,那麼,這篇文章非常適合你。2. 你可能對理解 tensorflow 中的會話,變量和類等有困擾,並且計劃轉向 pytorch,很好,你來對地方了。3. 如果你能夠用 pytorch 構建重要、複雜的模型,並且現在正在找尋一些實現細節,不好意思,你可以直接跳到最後一部分。