醫學圖像分割最佳方法的全面比較:U-Net和U-Net++

2022-01-04 計算機視覺Daily

介紹

語義分割是計算機視覺的一個問題,我們的任務是使用圖像作為輸入,為圖像中的每個像素分配一個類。在語義分割的情況下,我們不關心是否有同一個類的多個實例(對象),我們只是用它們的類別來標記它們。有多種關於不同計算機視覺問題的介紹課程,但用一張圖片可以總結不同的計算機視覺問題:

語義分割在生物醫學圖像分析中有著廣泛的應用:x射線、MRI掃描、數字病理、顯微鏡、內窺鏡等。https://grand-challenge.org/challenges上有許多不同的有趣和重要的問題有待探索。

從技術角度來看,如果我們考慮語義分割問題,對於N×M×3(假設我們有一個RGB圖像)的圖像,我們希望生成對應的映射N×M×k(其中k是類的數量)。有很多架構可以解決這個問題,但在這裡我想談談兩個特定的架構,Unet和Unet++。

有許多關於Unet的評論,它如何永遠地改變了這個領域。它是一個統一的非常清晰的架構,由一個編碼器和一個解碼器組成,前者生成圖像的表示,後者使用該表示來構建分割。每個空間解析度的兩個映射連接在一起(灰色箭頭),因此可以將圖像的兩種不同表示組合在一起。並且它成功了!

接下來是使用一個訓練好的編碼器。考慮圖像分類的問題,我們試圖建立一個圖像的特徵表示,這樣不同的類在該特徵空間可以被分開。我們可以(幾乎)使用任何CNN,並將其作為一個編碼器,從編碼器中獲取特徵,並將其提供給我們的解碼器。據我所知,Iglovikov & Shvets 使用了VGG11和resnet34分別為Unet解碼器以生成更好的特徵和提高其性能。

TernausNet (VGG11 Unet)

Unet++是最近對Unet體系結構的改進,它有多個跳躍連接。

根據論文, Unet++的表現似乎優於原來的Unet。就像在Unet中一樣,這裡可以使用多個編碼器(骨幹)來為輸入圖像生成強特徵。

我應該使用哪個編碼器?

這裡我想重點介紹Unet和Unet++,並比較它們使用不同的預訓練編碼器的性能。為此,我選擇使用胸部x光數據集來分割肺部。這是一個二值分割,所以我們應該給每個像素分配一個類為「1」的概率,然後我們可以二值化來製作一個掩碼。首先,讓我們看看數據。

來自胸片X光數據集的標註數據的例子

這些是非常大的圖像,通常是2000×2000像素,有很大的mask,從視覺上看,找到肺不是問題。使用segmentation_models_pytorch庫,我們為Unet和Unet++使用100+個不同的預訓練編碼器。我們做了一個快速的pipeline來訓練模型,使用Catalyst (pytorch的另一個庫,這可以幫助你訓練模型,而不必編寫很多無聊的代碼)和Albumentations(幫助你應用不同的圖像轉換)。

定義數據集和增強。我們將調整圖像大小為256×256,並對訓練數據集應用一些大的增強。
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict

class ChestXRayDataset(Dataset):
    def __init__(
        self,
        images,
        masks,
            transforms):
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self):
        return(len(self.images))

    def __getitem__(self, idx):
        """Will load the mask, get random coordinates around/with the mask,
        load the image by coordinates
        """
        sample_image = imread(self.images[idx])
        if len(sample_image.shape) == 3:
            sample_image = sample_image[..., 0]
        sample_image = np.expand_dims(sample_image, 2) / 255
        sample_mask = imread(self.masks[idx]) / 255
        if len(sample_mask.shape) == 3:
            sample_mask = sample_mask[..., 0]  
        augmented = self.transforms(image=sample_image, mask=sample_mask)
        sample_image = augmented['image']
        sample_mask = augmented['mask']
        sample_image = sample_image.transpose(2, 0, 1)  # channels first
        sample_mask = np.expand_dims(sample_mask, 0)
        data = {'features': torch.from_numpy(sample_image.copy()).float(),
                'mask': torch.from_numpy(sample_mask.copy()).float()}
        return(data)
    
def get_valid_transforms(crop_size=256):
    return A.Compose(
        [
            A.Resize(crop_size, crop_size),
        ],
        p=1.0)

def light_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
    ])

def medium_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.OneOf(
            [
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
    ])


def heavy_training_transforms(crop_size=256):
    return A.Compose([
        A.RandomResizedCrop(height=crop_size, width=crop_size),
        A.OneOf(
            [
                A.Transpose(),
                A.VerticalFlip(),
                A.HorizontalFlip(),
                A.RandomRotate90(),
                A.NoOp()
            ], p=1.0),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf(
            [
                A.CoarseDropout(max_holes=16, max_height=16, max_width=16),
                A.NoOp()
            ], p=1.0),
    ])

def get_training_trasnforms(transforms_type):
    if transforms_type == 'light':
        return(light_training_transforms())
    elif transforms_type == 'medium':
        return(medium_training_transforms())
    elif transforms_type == 'heavy':
        return(heavy_training_transforms())
    else:
        raise NotImplementedError("Not implemented transformation configuration")

定義模型和損失函數。這裡我們使用帶有regnety_004編碼器的Unet++,並使用RAdam + Lookahed優化器使用DICE + BCE損失之和進行訓練。
import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from catalyst import dl, metrics, core, contrib, utils
import torch.nn as nn
from skimage.io import imread
import os
from sklearn.model_selection import train_test_split
from catalyst.dl import  CriterionCallback, MetricAggregationCallback
encoder = 'timm-regnety_004'
model = smp.UnetPlusPlus(encoder, classes=1, in_channels=1)
#model.cuda()
learning_rate = 5e-3
encoder_learning_rate = 5e-3 / 10
layerwise_params = {"encoder*": dict(lr=encoder_learning_rate, weight_decay=0.00003)}
model_params = utils.process_model_params(model, layerwise_params=layerwise_params)
base_optimizer = contrib.nn.RAdam(model_params, lr=learning_rate, weight_decay=0.0003)
optimizer = contrib.nn.Lookahead(base_optimizer)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25, patience=10)
criterion = {
    "dice": DiceLoss(mode='binary'),
    "bce": nn.BCEWithLogitsLoss()
}

callbacks = [
    # Each criterion is calculated separately.
    CriterionCallback(
       input_key="mask",
        prefix="loss_dice",
        criterion_key="dice"
    ),
    CriterionCallback(
        input_key="mask",
        prefix="loss_bce",
        criterion_key="bce"
    ),

    # And only then we aggregate everything into one loss.
    MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum", 
        metrics={
            "loss_dice": 1.0, 
            "loss_bce": 0.8
        },
    ),

    # metrics
    IoUMetricsCallback(
        mode='binary', 
        input_key='mask', 
    )
    
]

runner = dl.SupervisedRunner(input_key="features", input_target_key="mask")
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    logdir='../logs/xray_test_log',
    num_epochs=100,
    main_metric="loss",
    minimize_metric=True,
    verbose=True,
)

如果我們用不同的編碼器對Unet和Unet++進行驗證,我們可以看到每個訓練模型的驗證質量,並總結如下:

Unet和Unet++驗證集分數

我們注意到的第一件事是,在所有編碼器中,Unet++的性能似乎都比Unet好。當然,有時這種差異並不是很大,我們不能說它們在統計上是否完全不同 —— 我們需要在多個folds上訓練,看看分數分布,單點不能證明任何事情。第二,resnest200e顯示了最高的質量,同時仍然有合理的參數數量。有趣的是,如果我們看看https://paperswithcode.com/task/semantic-segmentation,我們會發現resnest200在一些基準測試中也是SOTA。

好的,但是讓我們用Unet++和Unet使用resnest200e編碼器來比較不同的預測。

Unet和Unet++使用resnest200e編碼器的預測。左圖顯示了兩種模型的預測差異

在某些個別情況下,Unet++實際上比Unet更糟糕。但總的來說似乎更好一些。

一般來說,對於分割網絡來說,這個數據集看起來是一個容易的任務。讓我們在一個更難的任務上測試Unet++。為此,我使用PanNuke數據集,這是一個帶標註的組織學數據集(205,343個標記核,19種不同的組織類型,5個核類)。數據已經被分割成3個folds。

PanNuke樣本的例子

我們可以使用類似的代碼在這個數據集上訓練Unet++模型,如下所示:

驗證集上的Unet++得分

我們在這裡看到了相同的模式 - resnest200e編碼器似乎比其他的性能更好。我們可以用兩個不同的模型(最好的是resnest200e編碼器,最差的是regnety_002)來可視化一些例子。

resnest200e和regnety_002的預測

我們可以肯定地說,這個數據集是一項更難的任務 —— 不僅mask不夠精確,而且個別的核被分配到錯誤的類別。然而,使用resnest200e編碼器的Unet++仍然表現很好。

總結

這不是一個全面語義分割的指導,這更多的是一個想法,使用什麼來獲得一個堅實的基線。有很多模型、FPN,DeepLabV3, Linknet與Unet有很大的不同,有許多Unet-like架構,例如,使用雙編碼器的Unet,MAnet,PraNet,U²-net — 有很多的型號供你選擇,其中一些可能在你的任務上表現的比較好,但是,一個堅實的基線可以幫助你從正確的方向上開始。

相關焦點

  • 圖像分割的U-Net系列方法
    來源丨https://zhuanlan.zhihu.com/p/57530767僅作學術交流,如有侵權,請聯繫刪文在圖像分割任務特別是醫學圖像分割中
  • 圖像分割中的深度學習:U-Net 體系結構
    user=vZA2pjwAAAAJ&hl=en)U-Net是一種卷積神經網絡(CNN)方法,由Olaf Ronneberger、Phillip Fischer和Thomas Brox於2015年首次提出,它可以更好的分割生物醫學圖像。
  • U-Net 和 ResNet:長短跳躍連接的重要性(生物醫學圖像分割)
    這次,我們來聊一聊用於生物醫學圖像分割的的一種全卷積神經網絡,這個網絡帶有長短跳躍連接。在RoR中,通過使用長短跳躍連接,圖像分類準確性得到提高。實驗結果證明了使用長短跳躍連接的有效性。這一次,作者還提供了一種通過分析網絡中的權重來展示其有效性的方法,而不僅僅是展示實驗結果。儘管這項工作的目的是進行生物醫學圖像分割,但通過觀察網絡內的權重,我們可以更好地理解長短跳躍連接。它發布於2016年DLMIA(醫學圖像分析中的深度學習),引用次數超過100次。
  • 快速回顧:U-Net Family方法匯總
    目的是通過對每個像素進行分類來識別圖像中不同對象的位置和形狀。2 Metrics我們需要一組度量標準來比較不同的模型,通常有BCE,Dice係數和IoU。儘管U-Net也可用於自然圖像的分割,但它在許多生物醫學圖像的圖像分割任務中都得到不錯得效果。Requires fewer training samples深度學習模型的成功訓練需要成千上萬的帶注釋的訓練樣本,但是獲取帶注釋的醫學圖像是巨大的。U-Net可以使用更少的訓練樣本進行端到端訓練。
  • 基於注意力機制改進U-Net的醫學圖像分割算法
    近年來,由於深度學習方法的迅速發展,基於深度學習的圖像分割算法在醫學圖像分割領域取得了顯著的成就。其中依賴於編碼器-解碼器體系結構的U-Net被研究人員廣泛使用。但是U-Net網絡在下採樣的過程中卷積、池化都是局部算子,要獲取全局信息就需要深度編碼器,這樣會引入大量的訓練參數,並且丟失更多圖像的空間信息。而在上採樣過程中使用反卷積、反池化很難進行空間信息的恢復。
  • 深度學習第33講:CNN圖像語義分割和實例分割綜述
    10 篇論文,對兩階段的目標檢測 R-CNN 系列算法和一步走的 yolo 系列算法有了一個全面的了解和概況。比如說圖像有多個人甲、乙、丙,那邊他們的語義分割結果都是人,而實例分割結果卻是不同的對象,具體如下圖第二第三兩幅小圖所示: 目標檢測、語義分割和實例分割      以語義分割和實例分割為代表的圖像分割技術在各領域都有廣泛的應用,例如在無人駕駛和醫學影像分割等方面。
  • 一文概覽主要語義分割網絡:FCN,SegNet,U-Net...
    全卷積網絡(FCNs)可以用於自然圖像的語義分割、多模態醫學圖像分析和多光譜衛星圖像分割。與 AlexNet、VGG、ResNet 等深度分類網絡類似,FCNs 也有大量進行語義分割的深層架構。U-Net 在 EM 數據集上取得了最優異的結果,該數據集只有 30 個密集標註的醫學圖像和其他醫學圖像數據集,U-Net 後來擴展到 3D 版的 3D-U-Net。
  • 醫學圖像語義分割最佳方法的全面比較:UNet和UNet++
    語義分割在生物醫學圖像分析中有著廣泛的應用:x射線、MRI掃描、數字病理、顯微鏡、內窺鏡等。https://grand-challenge.org/challenges上有許多不同的有趣和重要的問題有待探索。從技術角度來看,如果我們考慮語義分割問題,對於N×M×3(假設我們有一個RGB圖像)的圖像,我們希望生成對應的映射N×M×k(其中k是類的數量)。
  • 來聊聊U-Net及其變種
    醫學圖像特點1.圖像語義較為簡單、結構較為固定。我們做腦的,就用腦CT和腦MRI,做胸片的只用胸片CT,做眼底的只用眼底OCT,都是一個固定的器官的成像,而不是全身的。由於器官本身結構固定和語義信息沒有特別豐富,所以高級語義信息和低級特徵都顯得很重要(UNet的skip connection和U型結構就派上了用場)。
  • 人工智慧・圖像分割(5)
    在《人工智慧・物體識別模型(後篇)》一文中,寶寶們曾經看到過,U-Net所在的論文是《U-Net: Convolutional Networks for Biomedical Image Segmentation》,它和Mask RCNN一樣,都能解決圖像分割問題。
  • 圖像語義分割入門:FCN/U-Net網絡解析
    (Semantic Segmentation)是圖像處理和是機器視覺技術中關於圖像理解的重要一環,也是 AI 領域中一個重要的分支。這些抽象特徵對物體的大小、位置和方向等敏感性更低,從而有助於分類性能的提高。這些抽象的特徵對分類很有幫助,可以很好地判斷出一幅圖像中包含什麼類別的物體。圖像分類是圖像級別的!
  • 常用圖像閾值分割算法
    下面對各種方法進行混合展示:第一類:全局閾值處理圖像閾值化分割是一種傳統的最常用的圖像分割方法,因其實現簡單、計算量小、性能較穩定而成為圖像分割中最基本和應用最廣泛的分割技術。它特別適用於目標和背景佔據不同灰度級範圍的圖像。難點在於如何選擇一個合適的閾值實現較好的分割。
  • U-Net:基於小樣本的高精度醫學影像語義分割模型
    https://zhuanlan.zhihu.com/p/68147968原論文地址:U-Net: Convolutional Networks for Biomedical Image SegmentationPytorch 實現:github.com/milesial/Pyt一、U-Net 概述U-Net 作為一個圖像語義分割網絡
  • 實戰 | 基於SegNet和U-Net的遙感圖像語義分割
    這兩周數據挖掘課期末project我們組選的課題也是遙感圖像的語義分割,所以剛好又把前段時間做的成果重新整理和加強了一下,故寫了這篇文章,記錄一下用深度學習做遙感圖像語義分割的完整流程以及一些好的思路和技巧。
  • TLU-Net:表面缺陷自動檢測的深度學習方法
    近年來研究了幾種基於機器學習的自動視覺檢測(AVI)方法。然而,由於訓練時間和AVI方法的不準確性,大多數鋼鐵製造行業仍然使用人工目視檢查。自動鋼缺陷檢測方法在成本更低和更快的質量控制和反饋方面是有用的。但是,為分割和分類準備帶注釋的訓練數據可能是一個昂貴的過程。在這項工作中,我們建議使用基於遷移學習的U-Net (tu - net)框架來檢測鋼表面缺陷。
  • FCN、Unet、Unet++:醫學圖像分割網絡一覽
    Unet在醫學圖像上的適用與CNN分割算法的簡要總結一、相關知識點解釋語義分割(Semantic Segmentation):就是對一張圖像上的所有像素點進行分類。(eg: FCN/Unet/Unet++/...)
  • 第一批在 SQUAD 2.0 上刷榜的 U-NET 模型,它們有何高明之處?
    幸運的是,前四名的表現並沒有太大的不同,所以我們可以看看一些高性能的想法。最佳方案採用了基於 U-net 的架構,相關的論文連結如下:https://arxiv.org/abs/1810.06638  。本文也將從這裡展開。「U-net 背後的思想是什麼?」
  • 總結|圖像分割5大經典方法
    主要分割方法有:基於閾值的分割方法https://www.cnblogs.com/wangduo/p/5556903.html閾值法的基本思想是基於圖像的灰度特徵來計算一個或多個灰度閾值,並將圖像中每個像素的灰度值與閾值相比較,最後將像素根據比較結果分到合適的類別中。
  • 如何讀文獻(U-Net)
    U-Net最早在2015年MICCAI會議上提出,用於解決生物醫學圖像分割問題。目前已達到19345次引用。其採用的Encoder-Decoder結構和Skip Connection是一種非常經典的設計方法,由於其架構之優美以及顯著的效果,目前在超解析度、去噪等領域都大放異彩。但是,其對醫學影像分割領域的影響是最大的,以至於大多數基於CNN的醫學影像分割任務都是基於類似的架構。
  • 大盤點 | 2020年5篇圖像分割算法最佳綜述
    最近,由於深度學習模型在各種視覺應用中的成功,已經有大量旨在利用深度學習模型開發圖像分割方法的工作。本文提供了對文獻的全面回顧,涵蓋了語義和實例級分割的眾多開創性作品,包括全卷積像素標記網絡,編碼器-解碼器體系結構,多尺度以及基於金字塔的方法,遞歸網絡,視覺注意模型和對抗環境中的生成模型。