小白學PyTorch | 5 torchvision預訓練模型與數據集全覽

2021-02-20 天池大數據科研平臺

文章目錄:

1 torchvision.datssets

2 torchvision.models

模型比較

本文建議複製代碼去跑跑看,增加一下手感。公眾號回復【torchvision】獲取代碼和數據。

torchvision

官網上的介紹:The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.

翻譯過來就是:torchvision包由流行的數據集、模型體系結構和通用的計算機視覺圖像轉換組成。簡單地說就是常用數據集+常見模型+常見圖像增強方法

這個torchvision中主要有包組成:

1 torchvision.datssets

包含賊多的數據集,包含下面的:

官方說明了:All the datasets have almost similar API. They all have two common arguments: transform and target_transform to transform the input and target respectively.

翻譯過來就是:每一個數據集的API都是基本相同的。他們都有兩個相同的參數:transform和target_transform(後面細講)

我們就用最經典最簡單的MNIST手寫數字數據集作為例子,先看這個的API:

包含5個參數:

root:就是你想要保存MNIST數據集的位置,如果download是Flase的話,則會從目標位置讀取數據集;download:True的話就會自動從網上下載這個數據集,到root的位置;train:True的話,數據集下載的是訓練數據集;False的話則下載測試數據集(真方便,都不用自己劃分了)transform:這個是對圖像進行處理的transform,比方說旋轉平移縮放,輸入的是PIL格式的圖像(不是tensor矩陣);target_transform:這個是對圖像標籤進行處理的函數(這個我沒用過不太確定,也許是做標籤平滑那種的處理?)

【下面用代碼進一步理解】

import torchvision
mydataset = torchvision.datasets.MNIST(root='./',
                                      train=True,
                                      transform=None,
                                      target_transform=None,
                                      download=True)

運行結果如下,表示下載完畢。

之後我們需要用到上一節課講到的dataloader的內容:

from torch.utils.data import Dataset,DataLoader
myloader = DataLoader(dataset=mydataset,
                     batch_size=16)
for i,(data,label) in enumerate(myloader):
    print(data.shape)
    print(label.shape)
    break

這時候會拋出一個錯誤:

大致看一看,就是pytorch的這個dataloader不是可以把數據集分成batch嘛,這個dataloder只能把tensor或者numpy這樣的組合成batch,而現在的數據集的格式是PIL格式。這裡驗證了之前說到的,transform這個輸入是PIL格式的圖片,解決方法是:transform不能是None,我們需要將PIL轉化成tensor才可以

所以我們把上面的transform稍作修改:

mydataset = torchvision.datasets.MNIST(root='./',
                                      train=True,        
                                      transform=torchvision.transforms.ToTensor(),
                                      target_transform=None,
                              ![](https://p3-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/071a7b749c094d30b482c29f16f8ec08~tplv-k3u1fbpfcp-zoom-1.image)        download=True)

重新運行的時候可以得到結果:結果中,16表示一個batch有16個樣本,1表示這是單通道的灰度圖片,28表示MNIST數據集圖片是

想要獲取其他的數據集也是一樣的,不過這裡就用MNIST作為舉例,其他的相同。

2 torchvision.models

預訓練模型中torchvision提供了很多種,大體分成下面四類:

分別是分類模型,語義模型,目標檢測模型和視頻分類模型。這裡呢因為分類模型比較常見也比較基礎,就主要介紹這個好啦。

在torch1.6.0版本中(應該是比較近的版本),主要包含下面的預訓練模型:

構建模型可以通過下面的代碼:

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()

這樣構建的模型的權重值是隨機的,只有結構是保存的。想要獲取預訓練的模型,則需要設置參數pretrained:

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)

我看官網的英文講解,提到了一點:似乎這些模型的預訓練數據集都是ImageNet的那個數據集,輸入圖片都是3通道的,並且要求輸入圖片的寬高不小於224像素,並且要求輸入圖片像素值的範圍在0到1之間,然後做一個normalization標準化。

不知道各位在看一些案例的時候,有沒有看到這個標準化:mean = [0.485, 0.456, 0.406] 和 std = [0.229, 0.224, 0.225],這個應該是ImageNet的圖片的標準化的參數。

這些預訓練的模型參數不確定能不能直接下載,我也就把這些模型存起來一併放在了公眾號的後臺,依然是回復【torchvision】獲取。

得到了.pth文件之後使用torch.load來加載即可。

# torch.save(model, 'model.pth')
model = torch.load('model.pth')

模型比較

最後呢,torchvision官方提供了一個不同模型在Imagenet 1-crop 的一個錯誤率的比較。可以一起來看看到底哪個模型比較好使。這裡我放了一些常見的模型。。像是Wide ResNet這種變種我就不放了。

網絡Top-1 errorTop-5 errorAlexNet43.4520.91VGG-1130.9811.37VGG-1330.0710.75VGG-1628.419.62VGG-1927.629.12VGG-13 with BN28.459.63VGG-19 with BN25.768.15Resnet-1830.2410.92Resnet-3426.708.58Resnet-5023.857.13Resnet-10122.636.44Resnet-15221.695.94SqueezeNet 1.141.8119.38Densenet-16122.356.2

整體來看,還是Resnet殘差網絡效果好。不過EfficientNet效果更好,不過這個模型在Torchvision中沒有提供,會在之後專門講解和提供代碼模板。(先挖坑)。

點讚+在看+轉發

相關焦點

  • 【小白學PyTorch】5.torchvision預訓練模型與數據集全覽
    翻譯過來就是:torchvision包由流行的數據集、模型體系結構和通用的計算機視覺圖像轉換組成。簡單地說就是常用數據集+常見模型+常見圖像增強方法這個torchvision中主要有包組成: 1 torchvision.datssets 包含賊多的數據集,包含下面的:官方說明了:All the datasets have almost
  • 視覺工具包torchvision重大更新:支持分割模型、檢測模型
    這次,工具包裡增加了許多新模型:做語義分割的,做目標檢測的,做實例分割的……也增加了許多數據集,比如ImageNet,CelebA,Caltech 101等等等等。另外,torchvision還有了不少視覺專用的C++/Cuda算子。
  • PyTorch專欄(八):微調基於torchvision 0.3的目標檢測模型
    微調基於torchvision 0.3的目標檢測模型使用Sequence2Sequence網絡和注意力進行翻譯在這篇文章中,我們將微調在 Penn-Fudan 資料庫中對行人檢測和分割的已預先訓練的 Mask R-CNN 模型。它包含170個圖像和345個行人實例,我們 將用它來說明如何在 torchvision 中使用新功能,以便在自定義數據集上訓練實例分割模型。
  • 小白學PyTorch | 15 TF2實現一個簡單的服裝分類任務
    (附代碼)小白學PyTorch | 5 torchvision預訓練模型與數據集全覽小白學PyTorch | 4 構建模型三要素與權重初始化小白學PyTorch | 3 淺談Dataset和Dataloader
  • 小白學PyTorch | 12 SENet詳解及PyTorch實現
    小白學PyTorch | 8 實戰之MNIST小試牛刀小白學PyTorch | 7 最新版本torchvision.transforms常用API翻譯與講解小白學PyTorch | 6 模型的構建訪問遍歷存儲
  • 【小白學PyTorch】7 最新版本torchvision.transforms常用API翻譯與講解
    之前的課程提到了,在torchvision官方的數據集中,提供的數據是PIL格式的數據,然後我們需要轉成FloatTensor形式的數據。因此這裡圖像增強的處理也分成在PIL圖片上操作的和在FloatTensor張量上操作的兩種。
  • 小白學PyTorch | 17 TFrec文件的創建與讀取
    小白學PyTorch | 8 實戰之MNIST小試牛刀小白學PyTorch | 7 最新版本torchvision.transforms常用API翻譯與講解小白學PyTorch | 6 模型的構建訪問遍歷存儲
  • 【小白學PyTorch】18.TF2構建自定義模型
    6 模型的構建訪問遍歷存儲(附代碼)小白學PyTorch | 5 torchvision預訓練模型與數據集全覽小白學PyTorch | 4 構建模型三要素與權重初始化小白學PyTorch | 3 淺談Dataset和Dataloader
  • 一行代碼即可調用18款主流模型!PyTorch Hub輕鬆解決論文可復現性
    PyTorch Hub包含了一系列與圖像分類、圖像分割、生成以及轉換相關的預訓練模型庫,例如ResNet、BERT、GPT、VGG、PGAN、MobileNet等經典模型,PyTorch Hub試圖以最傻瓜的方式,提高研究工作的復現性。有多簡單呢?
  • 8億參數,刷新ImageNet紀錄:何愷明團隊開源最強ResNeXt預訓練模型
    我頭一次聽說,在更大的預訓練集面前,ImageNet成了微調用的小語料庫。9.4億張圖?誰能做完這麼多計算?所以現在好了,你並不需要做這樣大大大量的計算,可以直接從預訓練的模型開始。更好的是,開源的不止這一個模型。
  • LogME:通用快速準確的預訓練模型評估方法
    它能極大地加速預訓練模型選擇的過程,將衡量單個預訓練模型的時間從50個小時減少到一分鐘,瘋狂提速三千倍!問題描述預訓練模型選擇問題,就是針對用戶給定的數據集,從預訓練模型庫中選擇一個最適合的預訓練模型用於遷移學習。
  • 長文解讀綜述NLP中的預訓練模型(純乾貨)
    最後再加上一層全連接層來適應到具體的任務。預訓練:用u表示每一個token(詞),當設置窗口長度為k,預測句中的第i個詞時,則使用第i個詞之前的k個詞,同時也根據超參數Θ,來預測第i個詞最可能是什麼。GPT在一個8億單詞的語料庫上訓練,12個Decoder層,12個attention頭,隱藏層維度為768。GPT在自然語言推理、分類、問答、對比相似度的多種測評中均超越了之前的模型,且從小數據集如STS-B(約5.7k訓練數據實例)到大數據集(550k訓練數據)都表現優異。甚至通過預訓練,也能實現一些Zero-Shot任務。
  • 史上最強通用NLP模型誕生:狂攬7大數據集最佳紀錄
    在官博介紹了他們訓練的一個大規模無監督NLP模型,可以生成連貫的文本段落,刷新了7大數據集基準,並且能在未經預訓練的情況下,完成閱讀理解、問答、機器翻譯等多項不同的語言建模任務。無需預訓練就能完成多種不同任務且取得良好結果,相當於克服了「災難性遺忘」,簡直可謂深度學習研究者夢寐以求的「通用」模型!
  • 基於關係推理的自監督學習無標記訓練
    背景與挑戰在現代深度學習算法中,對未標記數據的手工標註是其主要局限性之一。為了訓練一個好的模型,我們通常需要準備大量的標記數據。在少數類和數據的情況下,我們可以使用帶有標籤的公共數據集的預訓練模型,並使用你的數據微調最後幾層即可。但是,當你的數據很大時(比如商店中的產品或人的臉,..),很容易遇到問題,並且僅通過幾個可訓練的層就很難學習模型。
  • 如何用PyTorch訓練圖像分類器
    它將介紹如何組織訓練數據,使用預訓練神經網絡訓練模型,然後預測其他圖像。為此,我將使用由Google地圖中的地圖圖塊組成的數據集,並根據它們包含的地形特徵對它們進行分類。我會在另一篇文章中介紹如何使用它(簡而言之:為了識別無人機起飛或降落的安全區域)。但是現在,我只想使用一些訓練數據來對這些地圖圖塊進行分類。下面的代碼片段來自Jupyter Notebook。
  • 60分鐘PyTorch快速教程(二):TORCH.AUTOGRAD簡介
    PyTorch中的神經網絡訓練我們從torchvision中加載一個預先訓練好的resnet8模型,然後生成一些隨機數據表示一個3通道,64 x 64的圖像,圖像的label也是隨機生成的。import torch, torchvisionmodel = torchvision.models.resnet18(pretrained=True)data = torch.rand(1, 3, 64, 64)labels = torch.rand(1, 1000)現在把數據丟到模型中做正向傳播:再計算誤差
  • PyTorch上也有Keras了,訓練模型告別Debug,只需專注數據和邏輯
    魚羊 發自 凹非寺量子位 報導 | 公眾號 QbitAI在開始一個新的機器學習項目時,難免要重新編寫訓練循環,加載模型,分布式訓練……然後在Debug的深淵裡看著時間譁譁流逝,而自己離項目核心還有十萬八千裡。
  • 擴展之Tensorflow2.0 | 19 TF2模型的存儲與載入
    小白學PyTorch | 8 實戰之MNIST小試牛刀小白學PyTorch | 7 最新版本torchvision.transforms常用API翻譯與講解小白學PyTorch | 6 模型的構建訪問遍歷存儲
  • 中文最佳,哈工大訊飛聯合發布全詞覆蓋中文BERT預訓練模型
    而在中文領域,哈工大訊飛聯合實驗室也於昨日發布了基於全詞覆蓋的中文 BERT 預訓練模型,在多個中文數據集上取得了當前中文預訓練模型的最佳水平,效果甚至超過了原版 BERT、ERINE 等中文預訓練模型。基於 Transformers 的雙向編碼表示(BERT)在多個自然語言處理任務中取得了廣泛的性能提升。
  • LeNet-5的Pytorch實現
    下圖顯示了其結構:輸入的二維圖像,先經過兩次卷積層到池化層,再經過全連接層,最後使用softmax分類作為輸出層。 LeNet-5 這個網絡雖然很小,但是它包含了深度學習的基本模塊:卷積層,池化層,全連接層。是其他深度學習模型的基礎, 這裡我們對LeNet-5進行深入分析。同時,通過實例分析,加深對與卷積層和池化層的理解。