語義分割是計算機視覺的一個問題,我們的任務是使用圖像作為輸入,為圖像中的每個像素分配一個類。在語義分割的情況下,我們不關心是否有同一個類的多個實例(對象),我們只是用它們的類別來標記它們。有多種關於不同計算機視覺問題的介紹課程,但用一張圖片可以總結不同的計算機視覺問題:
語義分割在生物醫學圖像分析中有著廣泛的應用:x射線、MRI掃描、數字病理、顯微鏡、內窺鏡等。https://grand-challenge.org/challenges上有許多不同的有趣和重要的問題有待探索。
從技術角度來看,如果我們考慮語義分割問題,對於N×M×3(假設我們有一個RGB圖像)的圖像,我們希望生成對應的映射N×M×k(其中k是類的數量)。有很多架構可以解決這個問題,但在這裡我想談談兩個特定的架構,Unet和Unet++。
有許多關於Unet的評論,它如何永遠地改變了這個領域。它是一個統一的非常清晰的架構,由一個編碼器和一個解碼器組成,前者生成圖像的表示,後者使用該表示來構建分割。每個空間解析度的兩個映射連接在一起(灰色箭頭),因此可以將圖像的兩種不同表示組合在一起。並且它成功了!
接下來是使用一個訓練好的編碼器。考慮圖像分類的問題,我們試圖建立一個圖像的特徵表示,這樣不同的類在該特徵空間可以被分開。我們可以(幾乎)使用任何CNN,並將其作為一個編碼器,從編碼器中獲取特徵,並將其提供給我們的解碼器。據我所知,Iglovikov & Shvets 使用了VGG11和resnet34分別為Unet解碼器以生成更好的特徵和提高其性能。
Unet++是最近對Unet體系結構的改進,它有多個跳躍連接。
根據論文, Unet++的表現似乎優於原來的Unet。就像在Unet中一樣,這裡可以使用多個編碼器(骨幹)來為輸入圖像生成強特徵。
這裡我想重點介紹Unet和Unet++,並比較它們使用不同的預訓練編碼器的性能。為此,我選擇使用胸部x光數據集來分割肺部。這是一個二值分割,所以我們應該給每個像素分配一個類為「1」的概率,然後我們可以二值化來製作一個掩碼。首先,讓我們看看數據。
來自胸片X光數據集的標註數據的例子這些是非常大的圖像,通常是2000×2000像素,有很大的mask,從視覺上看,找到肺不是問題。使用segmentation_models_pytorch庫,我們為Unet和Unet++使用100+個不同的預訓練編碼器。我們做了一個快速的pipeline來訓練模型,使用Catalyst (pytorch的另一個庫,這可以幫助你訓練模型,而不必編寫很多無聊的代碼)和Albumentations(幫助你應用不同的圖像轉換)。
定義數據集和增強。我們將調整圖像大小為256×256,並對訓練數據集應用一些大的增強。import albumentations as A
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
class ChestXRayDataset(Dataset):
def __init__(
self,
images,
masks,
transforms):
self.images = images
self.masks = masks
self.transforms = transforms
def __len__(self):
return(len(self.images))
def __getitem__(self, idx):
"""Will load the mask, get random coordinates around/with the mask,
load the image by coordinates
"""
sample_image = imread(self.images[idx])
if len(sample_image.shape) == 3:
sample_image = sample_image[..., 0]
sample_image = np.expand_dims(sample_image, 2) / 255
sample_mask = imread(self.masks[idx]) / 255
if len(sample_mask.shape) == 3:
sample_mask = sample_mask[..., 0]
augmented = self.transforms(image=sample_image, mask=sample_mask)
sample_image = augmented['image']
sample_mask = augmented['mask']
sample_image = sample_image.transpose(2, 0, 1) # channels first
sample_mask = np.expand_dims(sample_mask, 0)
data = {'features': torch.from_numpy(sample_image.copy()).float(),
'mask': torch.from_numpy(sample_mask.copy()).float()}
return(data)
def get_valid_transforms(crop_size=256):
return A.Compose(
[
A.Resize(crop_size, crop_size),
],
p=1.0)
def light_training_transforms(crop_size=256):
return A.Compose([
A.RandomResizedCrop(height=crop_size, width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
], p=1.0),
])
def medium_training_transforms(crop_size=256):
return A.Compose([
A.RandomResizedCrop(height=crop_size, width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
], p=1.0),
A.OneOf(
[
A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
A.NoOp()
], p=1.0),
])
def heavy_training_transforms(crop_size=256):
return A.Compose([
A.RandomResizedCrop(height=crop_size, width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
], p=1.0),
A.ShiftScaleRotate(p=0.75),
A.OneOf(
[
A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
A.NoOp()
], p=1.0),
])
def get_training_trasnforms(transforms_type):
if transforms_type == 'light':
return(light_training_transforms())
elif transforms_type == 'medium':
return(medium_training_transforms())
elif transforms_type == 'heavy':
return(heavy_training_transforms())
else:
raise NotImplementedError("Not implemented transformation configuration")
定義模型和損失函數。這裡我們使用帶有regnety_004編碼器的Unet++,並使用RAdam + Lookahed優化器使用DICE + BCE損失之和進行訓練。import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from catalyst import dl, metrics, core, contrib, utils
import torch.nn as nn
from skimage.io import imread
import os
from sklearn.model_selection import train_test_split
from catalyst.dl import CriterionCallback, MetricAggregationCallback
encoder = 'timm-regnety_004'
model = smp.UnetPlusPlus(encoder, classes=1, in_channels=1)
#model.cuda()
learning_rate = 5e-3
encoder_learning_rate = 5e-3 / 10
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = contrib.nn.Lookahead(base_optimizer)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)
criterion = {
"dice": DiceLoss(mode='binary'),
"bce": nn.BCEWithLogitsLoss()
}callbacks = [
# Each criterion is calculated separately.
CriterionCallback(
input_key="mask",
prefix="loss_dice",
criterion_key="dice"
),
CriterionCallback(
input_key="mask",
prefix="loss_bce",
criterion_key="bce"
),
# And only then we aggregate everything into one loss.
MetricAggregationCallback(
prefix="loss",
mode="weighted_sum",
metrics={
"loss_dice": 1.0,
"loss_bce": 0.8
},
),
# metrics
IoUMetricsCallback(
mode='binary',
input_key='mask',
)
]
runner = dl.SupervisedRunner(input_key="features", input_target_key="mask")
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
callbacks=callbacks,
logdir='../logs/xray_test_log',
num_epochs=100,
main_metric="loss",
minimize_metric=True,
verbose=True,
)如果我們用不同的編碼器對Unet和Unet++進行驗證,我們可以看到每個訓練模型的驗證質量,並總結如下:
Unet和Unet++驗證集分數我們注意到的第一件事是,在所有編碼器中,Unet++的性能似乎都比Unet好。當然,有時這種差異並不是很大,我們不能說它們在統計上是否完全不同 —— 我們需要在多個folds上訓練,看看分數分布,單點不能證明任何事情。第二,resnest200e顯示了最高的質量,同時仍然有合理的參數數量。有趣的是,如果我們看看https://paperswithcode.com/task/semantic-segmentation,我們會發現resnest200在一些基準測試中也是SOTA。
好的,但是讓我們用Unet++和Unet使用resnest200e編碼器來比較不同的預測。
Unet和Unet++使用resnest200e編碼器的預測。左圖顯示了兩種模型的預測差異在某些個別情況下,Unet++實際上比Unet更糟糕。但總的來說似乎更好一些。
一般來說,對於分割網絡來說,這個數據集看起來是一個容易的任務。讓我們在一個更難的任務上測試Unet++。為此,我使用PanNuke數據集,這是一個帶標註的組織學數據集(205,343個標記核,19種不同的組織類型,5個核類)。數據已經被分割成3個folds。
PanNuke樣本的例子我們可以使用類似的代碼在這個數據集上訓練Unet++模型,如下所示:
驗證集上的Unet++得分我們在這裡看到了相同的模式 - resnest200e編碼器似乎比其他的性能更好。我們可以用兩個不同的模型(最好的是resnest200e編碼器,最差的是regnety_002)來可視化一些例子。
resnest200e和regnety_002的預測我們可以肯定地說,這個數據集是一項更難的任務 —— 不僅mask不夠精確,而且個別的核被分配到錯誤的類別。然而,使用resnest200e編碼器的Unet++仍然表現很好。
總結這不是一個全面語義分割的指導,這更多的是一個想法,使用什麼來獲得一個堅實的基線。有很多模型、FPN,DeepLabV3, Linknet與Unet有很大的不同,有許多Unet-like架構,例如,使用雙編碼器的Unet,MAnet,PraNet,U²-net — 有很多的型號供你選擇,其中一些可能在你的任務上表現的比較好,但是,一個堅實的基線可以幫助你從正確的方向上開始。