pytorch編程之空間變換

2021-03-01 pytorch玩轉深度學習

在本教程中,您將學習如何使用稱為空間變換器網絡的視覺注意力機制來擴充網絡。您可以在 DeepMind 論文中詳細了解空間變壓器網絡。

空間變換器網絡是對任何空間變換的可區別關注的概括。空間變換器網絡(簡稱 STN)允許神經網絡學習如何對輸入圖像執行空間變換,以增強模型的幾何不變性。例如,它可以裁剪感興趣的區域,縮放並校正圖像的方向。這可能是一個有用的機制,因為 CNN 不會對旋轉和縮放以及更一般的仿射變換保持不變。

關於 STN 的最好的事情之一就是能夠將它簡單地插入到任何現有的 CNN 中。

# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

plt.ion() # interactive modeCopy

加載數據

在本文中,我們將嘗試使用經典的 MNIST 數據集。使用標準卷積網絡和空間變換器網絡。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='.', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(root='.', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])), batch_size=64, shuffle=True, num_workers=4)Copy

出:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!Copy

描述空間變壓器網絡

空間變壓器網絡可歸結為三個主要組成部分:

Note

我們需要包含 affine_grid 和 grid_sample 模塊的最新版本的 PyTorch。

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)

# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)

# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

# Spatial transformer network forward function
def stn(self, x):
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)

grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)

return x

def forward(self, x):
# transform the input
x = self.stn(x)

# Perform the usual forward pass
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)

model = Net().to(device)Copy

訓練模型

現在,讓我們使用 SGD 算法訓練模型。網絡正在以監督方式學習分類任務。同時,該模型以端到端的方式自動學習 STN。

optimizer = optim.SGD(model.parameters(), lr=0.01)

def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)

optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 500 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100\. * batch_idx / len(train_loader), loss.item()))
#
# A simple test procedure to measure STN the performances on MNIST.
#

def test():
with torch.no_grad():
model.eval()
test_loss = 0
correct = 0
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)

# sum up batch loss
test_loss += F.nll_loss(output, target, size_average=False).item()
# get the index of the max log-probability
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
.format(test_loss, correct, len(test_loader.dataset),
100\. * correct / len(test_loader.dataset)))Copy

可視化 STN 結果

現在,我們將檢查學習到的視覺注意力機制的結果。

我們定義了一個小的輔助函數,以便在訓練時可視化轉換。

def convert_image_np(inp):
"""Convert a Tensor to numpy image."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
return inp

# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.

def visualize_stn():
with torch.no_grad():
# Get a batch of training data
data = next(iter(test_loader))[0].to(device)

input_tensor = data.cpu()
transformed_input_tensor = model.stn(data).cpu()

in_grid = convert_image_np(
torchvision.utils.make_grid(input_tensor))

out_grid = convert_image_np(
torchvision.utils.make_grid(transformed_input_tensor))

# Plot the results side-by-side
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(in_grid)
axarr[0].set_title('Dataset Images')

axarr[1].imshow(out_grid)
axarr[1].set_title('Transformed Images')

for epoch in range(1, 20 + 1):
train(epoch)
test()

# Visualize the STN transformation on some input batch
visualize_stn()

plt.ioff()
plt.show()Copy

Out:

Train Epoch: 1 [0/60000 (0%)] Loss: 2.312544
Train Epoch: 1 [32000/60000 (53%)] Loss: 0.865688

Test set: Average loss: 0.2105, Accuracy: 9426/10000 (94%)

Train Epoch: 2 [0/60000 (0%)] Loss: 0.528199
Train Epoch: 2 [32000/60000 (53%)] Loss: 0.273284

Test set: Average loss: 0.1150, Accuracy: 9661/10000 (97%)

Train Epoch: 3 [0/60000 (0%)] Loss: 0.312562
Train Epoch: 3 [32000/60000 (53%)] Loss: 0.496166

Test set: Average loss: 0.1130, Accuracy: 9661/10000 (97%)

Train Epoch: 4 [0/60000 (0%)] Loss: 0.346181
Train Epoch: 4 [32000/60000 (53%)] Loss: 0.206084

Test set: Average loss: 0.0875, Accuracy: 9730/10000 (97%)

Train Epoch: 5 [0/60000 (0%)] Loss: 0.351175
Train Epoch: 5 [32000/60000 (53%)] Loss: 0.388225

Test set: Average loss: 0.0659, Accuracy: 9802/10000 (98%)

Train Epoch: 6 [0/60000 (0%)] Loss: 0.122667
Train Epoch: 6 [32000/60000 (53%)] Loss: 0.258372

Test set: Average loss: 0.0791, Accuracy: 9759/10000 (98%)

Train Epoch: 7 [0/60000 (0%)] Loss: 0.190197
Train Epoch: 7 [32000/60000 (53%)] Loss: 0.154468

Test set: Average loss: 0.0647, Accuracy: 9791/10000 (98%)

Train Epoch: 8 [0/60000 (0%)] Loss: 0.121149
Train Epoch: 8 [32000/60000 (53%)] Loss: 0.288490

Test set: Average loss: 0.0583, Accuracy: 9821/10000 (98%)

Train Epoch: 9 [0/60000 (0%)] Loss: 0.244609
Train Epoch: 9 [32000/60000 (53%)] Loss: 0.023396

Test set: Average loss: 0.0685, Accuracy: 9778/10000 (98%)

Train Epoch: 10 [0/60000 (0%)] Loss: 0.256878
Train Epoch: 10 [32000/60000 (53%)] Loss: 0.091626

Test set: Average loss: 0.0684, Accuracy: 9783/10000 (98%)

Train Epoch: 11 [0/60000 (0%)] Loss: 0.181910
Train Epoch: 11 [32000/60000 (53%)] Loss: 0.113193

Test set: Average loss: 0.0492, Accuracy: 9856/10000 (99%)

Train Epoch: 12 [0/60000 (0%)] Loss: 0.081072
Train Epoch: 12 [32000/60000 (53%)] Loss: 0.082513

Test set: Average loss: 0.0670, Accuracy: 9800/10000 (98%)

Train Epoch: 13 [0/60000 (0%)] Loss: 0.180748
Train Epoch: 13 [32000/60000 (53%)] Loss: 0.194512

Test set: Average loss: 0.0439, Accuracy: 9874/10000 (99%)

Train Epoch: 14 [0/60000 (0%)] Loss: 0.099560
Train Epoch: 14 [32000/60000 (53%)] Loss: 0.084377

Test set: Average loss: 0.0416, Accuracy: 9880/10000 (99%)

Train Epoch: 15 [0/60000 (0%)] Loss: 0.070021
Train Epoch: 15 [32000/60000 (53%)] Loss: 0.241336

Test set: Average loss: 0.0588, Accuracy: 9820/10000 (98%)

Train Epoch: 16 [0/60000 (0%)] Loss: 0.060536
Train Epoch: 16 [32000/60000 (53%)] Loss: 0.053016

Test set: Average loss: 0.0405, Accuracy: 9877/10000 (99%)

Train Epoch: 17 [0/60000 (0%)] Loss: 0.207369
Train Epoch: 17 [32000/60000 (53%)] Loss: 0.069607

Test set: Average loss: 0.1006, Accuracy: 9685/10000 (97%)

Train Epoch: 18 [0/60000 (0%)] Loss: 0.127503
Train Epoch: 18 [32000/60000 (53%)] Loss: 0.070724

Test set: Average loss: 0.0659, Accuracy: 9814/10000 (98%)

Train Epoch: 19 [0/60000 (0%)] Loss: 0.176861
Train Epoch: 19 [32000/60000 (53%)] Loss: 0.116980

Test set: Average loss: 0.0413, Accuracy: 9871/10000 (99%)

Train Epoch: 20 [0/60000 (0%)] Loss: 0.146933
Train Epoch: 20 [32000/60000 (53%)] Loss: 0.245741

Test set: Average loss: 0.0346, Accuracy: 9892/10000 (99%)

相關焦點

  • pytorch專題前言 | 為什麼要學習pytorch?
    >1.生物學科的朋友需要學編程麼?2.為什麼要學習pytorch呢?3.學習了pytorch我怎麼應用呢?4.按照什麼順序去學習pytorch呢?5.網上那麼多資料如何選擇呢?現在開始逐一的對以上問題提出自己的看法,可能想的不夠周全,歡迎討論區一起探討!1.生物學科的朋友需要學編程麼?需要!
  • PyTorch 0.2發布:更多NumPy特性,高階梯度、分布式訓練等
    論文地址:https://arxiv.org/abs/1706.02677distributed包遵循MPI風格編程模型,這意味著可以通過send、recv、all_reduce等函數在節點之間交換tensor。
  • 【Pytorch】pytorch權重初始化方式與原理
    來自 | 知乎地址 | https://zhuanlan.zhihu.com/p/100937718作者 | 機器學習入坑者編輯 | 機器學習算法與自然語言處理公眾號本文僅作學術分享,若侵權,請聯繫後臺刪文處理pytorch
  • PyTorch 1.0 正式版發布了!
    選自code.fb作者:ZACH DEVITO、YANGQING JIA、DMYTRO DZHULGAKOV、SOUMITH CHINTALA、JOSEPH SPISAK機器之心編譯,「我們在 PyTorch1.0 發布前解決了幾大問題,包括可重用、性能、程式語言和可擴展性。」
  • PyTorch 中的傅立葉卷積
    連結如下:https://github.com/fkodom/fft-conv-pytorch卷積在數據分析中無處不在。幾十年來,它們一直被用於信號和圖像處理。最近,它們成為現代神經網絡的重要組成部分。如果你處理數據的話,你可能會遇到錯綜複雜的問題。
  • 深度學習大講堂之pytorch入門
    今天小天就帶大家從數據操作、自動求梯度和神經網絡設計的pytorch版本三個方面來入門pytorch。1.2.4 線性代數官方文檔:https://pytorch.org/docs/stable/torch.html
  • 大家心心念念的PyTorch Windows官方支持來了
    GitHub 發布地址:https://github.com/pytorch/pytorch/releasesPyTorch 官網:http://pytorch.org/機器之心也嘗試在 Windows 安裝簡單的 CPU 版,如下所示我們使用 pip 可以非常輕鬆而流暢地安裝 PyTorch。但當前使用 Conda 安裝 PyTorch 會遇到一些問題,例如小編的 Conda 會報錯說找不到對應的包。
  • Github 2.2K星的超全PyTorch資源列表
    來自|Github作者|bharathgs編譯|機器之心(禁止二次轉載)一份極棒的 PyTorch
  • 庫、教程、論文實現,這是一份超全的PyTorch資源列表(Github 2.2K星)
    機器之心發現了一份極棒的 PyTorch 資源列表,該列表包含了與 PyTorch 相關的眾多庫、教程與示例、論文實現以及其他資源。
  • 【乾貨】史上最全的PyTorch學習資源匯總
    該文檔的官網:http://pytorchchina.com 。此github存儲庫包含兩部分:    o torchText.data:文本的通用數據加載器、抽象和迭代器(包括詞彙和詞向量)    o torchText.datasets:通用NLP數據集的預訓練加載程序 我們只需要通過pip install torchtext安裝好torchtext後,便可以開始體驗Torchtext 的種種便捷之處。
  • 《PyTorch中文手冊》來了
    由於其靈活、動態的編程環境和用戶友好的界面,PyTorch 是快速實驗的理想選擇。PyTorch 現在是 GitHub 上增長速度第二快的開源項目,在過去的 12 個月裡,貢獻者增加了 2.8 倍。而且,去年 12 月在 NeurIPS 大會上,PyTorch 1.0 穩定版終於發布。
  • 深度學習框架搭建之PyTorch
    深度學習框架搭建之PyTorchPyTorch 簡介PyTorch 是由 Facebook 推出,目前一款發展與流行勢頭非常強勁的深度學習框架。情感分析:https://github.com/bentrevett/pytorch-sentiment-analysis智能問答  Question Answering:https://github.com/allenai/document-qa機器翻譯  Translation: OpenNMT-py
  • 基於hough變換的直線檢測
    今天介紹的是Hough變換的直線檢測。首先必須要看原理,何為hough變換。     Hough變換的定義 : Hough變換的基本原理是將影像空間中的曲線變換到參數空間中,通過檢測參數空間中的極值點,確定出該曲線的描述參數,從而提取影像中的規則曲線。
  • 新手必備 | 史上最全的PyTorch學習資源匯總
    此github存儲庫包含兩部分:我們只需要通過pip install torchtext安裝好torchtext後,便可以開始體驗Torchtext 的種種便捷之處。(2)Pytorch-Seq2seq(https://github.com/IBM/pytorch-seq2seq):Seq2seq是一個快速發展的領域,新技術和新框架經常在此發布。
  • 深度學習自救指南(一)| Anaconda、PyTorch的下載和安裝
    選擇Anaconda的一個非常重要的原因是,我們在Anaconda中可以對不同的編程環境進行管理,比如在根據不同的項目需求選擇不同的python版本,簡化環境搭建過程,大大提高了開發效率。同時Anaconda中還附帶了兩個非常好用的交互式代碼編輯器(Spyder、Jupyter notebook)。
  • PyTorch 重大更新,0.4.0 版本支持 Windows 系統
    Tensor/Variable 合併零維張量dtypes遷移指導新特性Tensor全面支持高級索引快速傅立葉變換Github 連結:https://github.com/pytorch/pytorch/releases/tag/v0.4.0PyTorch 官網連結:http://pytorch.org/相關文章:PyTorch 團隊發表周年感言:感謝日益壯大的社群,這一年迎來六大核心突破迎來 PyTorch
  • 新版PyTorch 1.2 已發布:功能更多、兼容更全、操作更快!
    有關完整的 PyTorch 1.2 發行說明,請參見此處(https://github.com/pytorch/pytorch/releases)。對於大小的名稱,我們用前綴 n_(例如「大小(n_freq,n_mel)的張量」)命名,而維度名稱則不具有該前綴(例如「維度張量(通道,時間)」);並且所有變換和函數的輸入我們現在首先要假定通道。這樣做是為了與 PyTorch 保持一致,PyTorch 具有後跟樣本數量的通道,而且這個通道參數目前不推薦使用所有的轉換和函數。
  • 傅立葉變換、拉普拉斯變換、Z 變換的聯繫是什麼?為什麼要進行這些變換?
    要理解這些變換,首先需要理解什麼是數學變換!如果不理解什麼是數學變換的概念,那麼其他的概念我覺得也沒有理解。數學變換是指數學函數從原向量空間在自身函數空間變換,或映射到另一個函數空間,或對於集合X到其自身(比如線性變換)或從X到另一個集合Y的可逆變換函數。
  • PyTorch 中文教程最新版
    教程連結:http://pytorch123.com/作者:磐創 AI       翻譯小組: News & PanChuang原文:https://pytorch.org/tutorials/目錄第一章:PyTorch 之簡介與下載1.
  • 【Pytorch】PyTorch的4分鐘教程,手把手教你完成線性回歸
    下文出現的所有功能函數,均可以在中文文檔中查看具體參數和實現細節,先附上pytorch中文文檔連結:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/