Image segmentation has long been an active research area because of its potential to close gaps in medical care and help the public at large. Over the past five years, a variety of papers have proposed different objective (loss) functions for specific situations such as biased data and sparse segmentation. This article summarizes the loss functions most widely used for image segmentation and lists the situations in which each can help a model converge faster and better. It also introduces a new log-cosh Dice loss function and compares its performance against the widely used losses on the NBFS skull-stripping dataset. Some loss functions perform well on all datasets and are a good default choice when the data distribution is unknown.
Deep learning has transformed industries from software to manufacturing. It is also widely applied in medicine, for example tumor segmentation with U-Net and cancer detection with SegNet. In these applications image segmentation is crucial: beyond indicating that a disease is present, the segmented image shows exactly where it is located, which is the foundation for capabilities such as automatically detecting lesions in CT scans.
Image segmentation can be defined as classification at the pixel level. An image is composed of pixels, and groups of pixels together define the different elements in the image; a method that classifies these pixels into classes of elements is called semantic image segmentation. When designing deep learning architectures for complex segmentation tasks, one crucial choice is which loss/objective function to use, because it drives the learning process of the algorithm. The choice of loss function is essential for any architecture to learn the right objective, so since 2012 researchers have been designing domain-specific loss functions to obtain better results on their datasets.
This article summarizes 15 loss functions for image segmentation that have been shown to deliver state-of-the-art results in different domains. They can be roughly grouped into four categories: distribution-based, region-based, boundary-based, and compound loss functions.
It also discusses the conditions under which each objective/loss function is likely to be useful, and proposes a new log-cosh Dice loss for semantic image segmentation. To demonstrate its efficiency, the performance of all the loss functions is compared on the NBFS skull-stripping dataset.
Distribution-based loss functions
1. Binary Cross-Entropy
Cross-entropy is defined as a measure of the difference between two probability distributions over a given random variable or set of events. It is widely used in classification tasks, and since segmentation is pixel-level classification, it works well there too. In multi-class tasks the softmax activation is usually paired with the cross-entropy loss: cross-entropy measures the difference between two probability distributions, but the network outputs a vector that is not itself a probability distribution, so a softmax is used to normalize the vector into one before computing the loss.
The binary cross-entropy loss for a single sample $i$ is
$$L_i = -\big[\, y_i \log \hat{y}_i + (1 - y_i)\log(1 - \hat{y}_i) \,\big]$$
where $y_i$ is the label of sample $i$ (1 for the positive class, 0 for the negative class) and $\hat{y}_i$ is the predicted value. To compute the total loss over $N$ samples, the $N$ per-sample losses are simply summed:
$$L = -\sum_{i=1}^{N}\big[\, y_i \log \hat{y}_i + (1 - y_i)\log(1 - \hat{y}_i) \,\big]$$
Cross-entropy loss can be used in most semantic segmentation scenarios, but it has an obvious drawback when the task only distinguishes foreground from background: once the number of foreground pixels is far smaller than the number of background pixels, the $y = 0$ terms dominate the loss, so the model becomes heavily biased toward the background and the foreground is segmented poorly.
# Binary cross-entropy: the input must first be passed through a sigmoid
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.BCELoss()
loss = criterion(torch.sigmoid(input), target)

# Multi-class cross-entropy: no Softmax layer is needed before this loss
criterion = nn.CrossEntropyLoss()
loss = criterion(input, target)
2. Weighted Binary Cross-Entropy
Weighted binary cross-entropy simply adds a weight $\beta$ to the positive-class term of the cross-entropy loss:
$$L = -\big[\, \beta \, y \log \hat{y} + (1 - y)\log(1 - \hat{y}) \,\big]$$
Setting $\beta > 1$ reduces false negatives, while setting $\beta < 1$ reduces false positives. Compared with the plain cross-entropy loss, this yields better results when the sample counts are imbalanced.
class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
"""
Network has to have NO NONLINEARITY!
"""
def __init__(self, weight=None):
super(WeightedCrossEntropyLoss, self).__init__()
self.weight = weight
def forward(self, inp, target):
target = target.long()
num_classes = inp.size()[1]
i0 = 1
i1 = 2
while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once
inp = inp.transpose(i0, i1)
i0 += 1
i1 += 1
inp = inp.contiguous()
inp = inp.view(-1, num_classes)
target = target.view(-1,)
wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)
        return wce_loss(inp, target)
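For the binary case, the same weighting can be expressed directly with PyTorch's built-in pos_weight argument; a minimal usage sketch (not part of the original article, shapes are illustrative):

import torch
import torch.nn as nn

# beta > 1 penalizes false negatives more strongly, beta < 1 penalizes false positives
beta = 2.0
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([beta]))
logits = torch.randn(4, 1, 64, 64)                       # raw network outputs
target = torch.randint(0, 2, (4, 1, 64, 64)).float()     # binary ground-truth mask
loss = criterion(logits, target)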
3. Balanced Cross-Entropy
Balanced cross-entropy is similar to weighted cross-entropy, except that the negative samples are weighted as well:
$$L = -\big[\, \beta \, y \log \hat{y} + (1 - \beta)(1 - y)\log(1 - \hat{y}) \,\big]$$
where $\beta$ is typically derived from the proportion of background pixels.
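A minimal sketch of balanced binary cross-entropy consistent with the definition above (the function name and the per-batch estimate of beta are illustrative assumptions, not from the article):

import torch

def balanced_bce_loss(logits, target, eps=1e-7):
    prob = torch.sigmoid(logits)
    beta = 1.0 - target.mean()                       # fraction of background pixels in the batch
    pos = -beta * target * torch.log(prob + eps)
    neg = -(1.0 - beta) * (1.0 - target) * torch.log(1.0 - prob + eps)
    return (pos + neg).mean()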
4. Focal Loss
Focal loss was originally proposed for object detection. Its goal is to focus on hard examples by giving hard-to-classify samples a larger weight: for positive samples, examples predicted with a high probability (easy samples) contribute a smaller loss, while examples predicted with a low probability (hard examples) contribute a larger loss, so training concentrates on the hard cases:
$$FL(p_t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t)$$
The downside is the extra hyperparameters, which make tuning more difficult.
import numpy as np

class FocalLoss(nn.Module):
"""
copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
Focal_Loss= -1*alpha*(1-pt)*log(pt)
:param num_class:
:param alpha: (tensor) 3D or 4D the scalar factor for this criterion
:param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
focus on hard misclassified example
:param smooth: (float,double) smooth value when cross entropy
:param balance_index: (int) balance class index, should be specific when alpha is float
:param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
"""
def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
super(FocalLoss, self).__init__()
self.apply_nonlin = apply_nonlin
self.alpha = alpha
self.gamma = gamma
self.balance_index = balance_index
self.smooth = smooth
self.size_average = size_average
if self.smooth is not None:
if self.smooth < 0 or self.smooth > 1.0:
raise ValueError('smooth value should be in [0,1]')
def forward(self, logit, target):
if self.apply_nonlin is not None:
logit = self.apply_nonlin(logit)
num_class = logit.shape[1]
if logit.dim() > 2:
# N,C,d1,d2 -> N,C,m (m=d1*d2*...)
logit = logit.view(logit.size(0), logit.size(1), -1)
logit = logit.permute(0, 2, 1).contiguous()
logit = logit.view(-1, logit.size(-1))
target = torch.squeeze(target, 1)
target = target.view(-1, 1)
# print(logit.shape, target.shape)
#
alpha = self.alpha
if alpha is None:
alpha = torch.ones(num_class, 1)
elif isinstance(alpha, (list, np.ndarray)):
assert len(alpha) == num_class
alpha = torch.FloatTensor(alpha).view(num_class, 1)
alpha = alpha / alpha.sum()
elif isinstance(alpha, float):
alpha = torch.ones(num_class, 1)
alpha = alpha * (1 - self.alpha)
alpha[self.balance_index] = self.alpha
else:
raise TypeError('Not support alpha type')
if alpha.device != logit.device:
alpha = alpha.to(logit.device)
idx = target.cpu().long()
one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
one_hot_key = one_hot_key.scatter_(1, idx, 1)
if one_hot_key.device != logit.device:
one_hot_key = one_hot_key.to(logit.device)
if self.smooth:
one_hot_key = torch.clamp(
one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
pt = (one_hot_key * logit).sum(1) + self.smooth
logpt = pt.log()
gamma = self.gamma
alpha = alpha[idx]
alpha = torch.squeeze(alpha)
loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
        return loss
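A short usage sketch of the FocalLoss class above (the softmax nonlinearity and the shapes are illustrative assumptions):

# logits: (N, C, H, W) raw outputs; target: (N, H, W) integer class labels
criterion = FocalLoss(apply_nonlin=lambda x: torch.softmax(x, dim=1), gamma=2)
logits = torch.randn(2, 3, 64, 64)
target = torch.randint(0, 3, (2, 64, 64))
loss = criterion(logits, target)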
5. Distance map derived loss penalty term
A distance map can be defined as the distance (Euclidean, absolute, etc.) between the ground truth and the predicted map. There are two ways to incorporate such maps: build them into the network architecture, adding a reconstruction head next to the segmentation head, or fold them into the loss function. Following the latter idea, a distance map is derived from the GT mask and used to build a custom penalty-based loss. This makes it easy to guide the network toward boundary regions that are hard to segment. The loss can be written as
$$L = \frac{1}{N}\sum_{i}\big(1 + \phi_i\big)\; CE(\hat{y}_i, y_i)$$
where $\phi_i$ is the distance penalty map generated from the ground-truth mask (the constant 1 keeps the weights from vanishing away from the boundary).
class DisPenalizedCE(torch.nn.Module):
"""
Only for binary 3D segmentation
Network has to have NO NONLINEARITY!
"""
def forward(self, inp, target):
# print(inp.shape, target.shape) # (batch, 2, xyz), (batch, 2, xyz)
# compute distance map of ground truth
with torch.no_grad():
dist = compute_edts_forPenalizedLoss(target.cpu().numpy()>0.5) + 1.0
dist = torch.from_numpy(dist)
if dist.device != inp.device:
dist = dist.to(inp.device).type(torch.float32)
dist = dist.view(-1,)
target = target.long()
num_classes = inp.size()[1]
i0 = 1
i1 = 2
while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once
inp = inp.transpose(i0, i1)
i0 += 1
i1 += 1
inp = inp.contiguous()
inp = inp.view(-1, num_classes)
log_sm = torch.nn.LogSoftmax(dim=1)
inp_logs = log_sm(inp)
target = target.view(-1,)
# loss = nll_loss(inp_logs, target)
loss = -inp_logs[range(target.shape[0]), target]
# print(loss.type(), dist.type())
        weighted_loss = loss * dist
        return weighted_loss.mean()
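The code above calls a compute_edts_forPenalizedLoss helper that the article never defines; a possible implementation based on scipy's Euclidean distance transform (an assumption about the missing helper, not the original code):

import numpy as np
from scipy.ndimage import distance_transform_edt

def compute_edts_forPenalizedLoss(gt):
    # gt: binary array of shape (batch, x, y[, z]); returns a map that is large near
    # the foreground/background boundary and small far away from it
    res = np.zeros(gt.shape, dtype=np.float32)
    for b in range(gt.shape[0]):
        posmask = gt[b].astype(bool)
        negmask = ~posmask
        pos_edt = distance_transform_edt(posmask)
        pos_edt = (np.max(pos_edt) - pos_edt) * posmask     # high near the boundary inside the FG
        neg_edt = distance_transform_edt(negmask)
        neg_edt = (np.max(neg_edt) - neg_edt) * negmask     # high near the boundary in the BG
        res[b] = pos_edt / (np.max(pos_edt) + 1e-7) + neg_edt / (np.max(neg_edt) + 1e-7)
    return res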
Region-based loss functions
1. Dice Loss
The Dice coefficient is a metric widely used in computer vision to compute the similarity between two images. In 2016 it was also adapted into a loss function, known as the Dice loss.
Dice coefficient: a metric of set similarity, usually used to measure the pixel-wise similarity between two samples:
$$DSC = \frac{2\,|X \cap Y|}{|X| + |Y|}$$
The factor 2 in the numerator compensates for the overlap $|X \cap Y|$ being counted in both $|X|$ and $|Y|$ in the denominator; the coefficient takes values in $[0, 1]$. For segmentation, $X$ is the ground-truth mask and $Y$ is the predicted mask. The Dice loss is then
$$L_{Dice} = 1 - \frac{2\,|X \cap Y| + 1}{|X| + |Y| + 1}$$
Here 1 is added to the numerator and denominator so that the function remains well defined in extreme cases such as $y = 0$. Dice loss is suited to heavily imbalanced samples; in ordinary settings it can have an adverse effect on backpropagation and make training unstable.
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
"""
net_output must be (b, c, x, y(, z)))
gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
if mask is provided it must have shape (b, 1, x, y(, z)))
:param net_output:
:param gt:
:param axes:
:param mask: mask must be 1 for valid pixels and 0 for invalid pixels
:param square: if True then fp, tp and fn will be squared before summation
:return:
"""
if axes is None:
axes = tuple(range(2, len(net_output.size())))
shp_x = net_output.shape
shp_y = gt.shape
with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = sum_tensor(tp, axes, keepdim=False)
fp = sum_tensor(fp, axes, keepdim=False)
fn = sum_tensor(fn, axes, keepdim=False)
return tp, fp, fn
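get_tp_fp_fn above relies on a sum_tensor helper that the article never defines; a minimal sketch consistent with how it is called (treat it as an assumption about the missing helper):

def sum_tensor(inp, axes, keepdim=False):
    # sum a tensor over the given axes, optionally keeping the reduced dimensions
    for ax in sorted({int(a) for a in axes}, reverse=True):
        inp = inp.sum(dim=ax, keepdim=keepdim)
    return inp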
class SoftDiceLoss(nn.Module):
def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
square=False):
"""
paper: https://arxiv.org/pdf/1606.04797.pdf
"""
super(SoftDiceLoss, self).__init__()
self.square = square
self.do_bg = do_bg
self.batch_dice = batch_dice
self.apply_nonlin = apply_nonlin
self.smooth = smooth
def forward(self, x, y, loss_mask=None):
shp_x = x.shape
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.apply_nonlin is not None:
x = self.apply_nonlin(x)
tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
if not self.do_bg:
if self.batch_dice:
dc = dc[1:]
else:
dc = dc[:, 1:]
dc = dc.mean()
        return -dc
2. Tversky Loss
Paper: https://arxiv.org/pdf/1706.05721.pdf. The Tversky index is a generalization of the Dice and Jaccard coefficients: with α = β = 0.5 it reduces to the Dice coefficient, and with α = β = 1 it reduces to the Jaccard coefficient. α and β weight the false positives and false negatives respectively, so adjusting them controls the trade-off between false positives and false negatives:
$$TI = \frac{TP + \epsilon}{TP + \alpha\,FP + \beta\,FN + \epsilon}, \qquad L_{Tversky} = 1 - TI$$
class TverskyLoss(nn.Module):
def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
square=False):
"""
paper: https://arxiv.org/pdf/1706.05721.pdf
"""
super(TverskyLoss, self).__init__()
self.square = square
self.do_bg = do_bg
self.batch_dice = batch_dice
self.apply_nonlin = apply_nonlin
self.smooth = smooth
self.alpha = 0.3
self.beta = 0.7
def forward(self, x, y, loss_mask=None):
shp_x = x.shape
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.apply_nonlin is not None:
x = self.apply_nonlin(x)
tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)
if not self.do_bg:
if self.batch_dice:
tversky = tversky[1:]
else:
tversky = tversky[:, 1:]
tversky = tversky.mean()
        return -tversky
3. Focal Tversky Loss
Similar to focal loss, which emphasizes hard examples by down-weighting easy and common ones, the focal Tversky loss also tries to learn hard examples, such as cases with a small ROI (region of interest), with the help of a γ coefficient:
$$FTL = (1 - TI)^{\gamma}$$
class FocalTversky_loss(nn.Module):
"""
paper: https://arxiv.org/pdf/1810.07842.pdf
author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65
"""
def __init__(self, tversky_kwargs, gamma=0.75):
super(FocalTversky_loss, self).__init__()
self.gamma = gamma
self.tversky = TverskyLoss(**tversky_kwargs)
def forward(self, net_output, target):
tversky_loss = 1 + self.tversky(net_output, target) # = 1-tversky(net_output, target)
focal_tversky = torch.pow(tversky_loss, self.gamma)
        return focal_tversky
4. Sensitivity Specificity Loss
Sensitivity is simply the recall, the ability to detect the cases that really are diseased:
$$Sensitivity = \frac{TP}{TP + FN}$$
Specificity is the ability to recognize the cases that really are healthy:
$$Specificity = \frac{TN}{TN + FP}$$
The sensitivity-specificity loss is a weighted sum of the corresponding squared-error terms:
$$L_{SS} = \lambda\,\frac{\sum_{i}(y_i - p_i)^2\, y_i}{\sum_{i} y_i + \epsilon} \;+\; (1 - \lambda)\,\frac{\sum_{i}(y_i - p_i)^2\,(1 - y_i)}{\sum_{i}(1 - y_i) + \epsilon}$$
The left-hand term is the error rate on lesion (foreground) pixels, i.e. it corresponds to $1 - Sensitivity$ rather than the sensitivity itself, which is why λ is set to a small value such as 0.05 (the implementation below uses 0.1); $\epsilon$ is added to obtain smooth gradients.
class SSLoss(nn.Module):
def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
square=False):
"""
Sensitivity-Specifity loss
paper: http://www.rogertam.ca/Brosch_MICCAI_2015.pdf
tf code: https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392
"""
super(SSLoss, self).__init__()
self.square = square
self.do_bg = do_bg
self.batch_dice = batch_dice
self.apply_nonlin = apply_nonlin
self.smooth = smooth
self.r = 0.1 # weight parameter in SS paper
def forward(self, net_output, gt, loss_mask=None):
shp_x = net_output.shape
shp_y = gt.shape
# class_num = shp_x[1]
with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.apply_nonlin is not None:
softmax_output = self.apply_nonlin(net_output)
# no object value
bg_onehot = 1 - y_onehot
squared_error = (y_onehot - softmax_output)**2
specificity_part = sum_tensor(squared_error*y_onehot, axes)/(sum_tensor(y_onehot, axes)+self.smooth)
sensitivity_part = sum_tensor(squared_error*bg_onehot, axes)/(sum_tensor(bg_onehot, axes)+self.smooth)
ss = self.r * specificity_part + (1-self.r) * sensitivity_part
if not self.do_bg:
if self.batch_dice:
ss = ss[1:]
else:
ss = ss[:, 1:]
ss = ss.mean()
        return ss
5. Log-Cosh Dice Loss (the loss proposed in this article)
The Dice coefficient is a metric for evaluating segmentation outputs. It has also been adapted into a loss function because it gives a mathematical expression of the segmentation objective, but due to its non-convexity it often fails to reach the optimum. The Lovász-softmax loss tries to address the non-convexity by adding smoothing through the Lovász extension. Meanwhile, the log-cosh function has been widely used in regression problems to smooth the curve.
Combining the cosh(x) and log(x) functions with the Dice loss gives the log-cosh Dice loss:
$$L_{lc\text{-}dce} = \log\big(\cosh(L_{Dice})\big)$$
def log_cosh_dice_loss(self, y_true, y_pred):
x = self.dice_loss(y_true, y_pred)
    return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0)
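A self-contained PyTorch variant that reuses the SoftDiceLoss defined earlier; a minimal sketch, assuming (as above) that SoftDiceLoss returns the negative soft Dice score:

class LogCoshDiceLoss(nn.Module):
    def __init__(self, **dice_kwargs):
        super(LogCoshDiceLoss, self).__init__()
        self.dice = SoftDiceLoss(**dice_kwargs)    # e.g. pass apply_nonlin here

    def forward(self, net_output, target):
        dice_loss = 1.0 + self.dice(net_output, target)   # SoftDiceLoss returns -dice score
        return torch.log(torch.cosh(dice_loss))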
Boundary-based loss functions
1. Shape-aware Loss
As the name suggests, the shape-aware loss takes shape into account. Whereas most losses operate at the pixel level, the shape-aware loss computes the average point-to-curve Euclidean distance between points around the predicted segmentation curve and the ground-truth curve, and uses it as a coefficient on the cross-entropy loss (CE denotes the cross-entropy loss):
$$E_i = D\big(\hat{C}, C_{GT}\big), \qquad L_{shape\text{-}aware} = \sum_{i} E_i \; CE(y_i, \hat{y}_i)$$
The reference implementation below follows the closely related distance-map-penalized Dice loss.
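The classes in this and the following section call a softmax_helper function that the article never defines; a minimal sketch consistent with its usage (an assumption about the missing helper):

import torch.nn.functional as F

def softmax_helper(x):
    # channel-wise softmax over dim 1 (the class dimension)
    return F.softmax(x, dim=1)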
class DistBinaryDiceLoss(nn.Module):
"""
Distance map penalized Dice loss
Motivated by: https://openreview.net/forum?id=B1eIcvS45V
Distance Map Loss Penalty Term for Semantic Segmentation
"""
def __init__(self, smooth=1e-5):
super(DistBinaryDiceLoss, self).__init__()
self.smooth = smooth
def forward(self, net_output, gt):
"""
net_output: (batch_size, 2, x,y,z)
target: ground truth, shape: (batch_size, 1, x,y,z)
"""
net_output = softmax_helper(net_output)
# one hot code for gt
with torch.no_grad():
if len(net_output.shape) != len(gt.shape):
gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(net_output.shape)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
gt_temp = gt[:,0, ...].type(torch.float32)
with torch.no_grad():
dist = compute_edts_forPenalizedLoss(gt_temp.cpu().numpy()>0.5) + 1.0
# print('dist.shape: ', dist.shape)
dist = torch.from_numpy(dist)
if dist.device != net_output.device:
dist = dist.to(net_output.device).type(torch.float32)
tp = net_output * y_onehot
tp = torch.sum(tp[:,1,...] * dist, (1,2,3))
dc = (2 * tp + self.smooth) / (torch.sum(net_output[:,1,...], (1,2,3)) + torch.sum(y_onehot[:,1,...], (1,2,3)) + self.smooth)
dc = dc.mean()
        return -dc
2. Hausdorff Distance Loss
The Hausdorff distance (HD) is a metric that segmentation methods use to track model performance. It is defined as
$$d_H(X, Y) = \max\Big\{\, \sup_{x \in X}\inf_{y \in Y} d(x, y),\; \sup_{y \in Y}\inf_{x \in X} d(x, y) \Big\}$$
The aim of any segmentation model is to minimize the Hausdorff distance, but because of its non-convexity it has not been widely used directly as a loss function. Researchers have proposed three variants of Hausdorff-distance-based losses, all of which incorporate the metric while keeping the loss function tractable.
class HDDTBinaryLoss(nn.Module):
def __init__(self):
"""
compute haudorff loss for binary segmentation
https://arxiv.org/pdf/1904.10030v1.pdf
"""
super(HDDTBinaryLoss, self).__init__()
def forward(self, net_output, target):
"""
net_output: (batch_size, 2, x,y,z)
target: ground truth, shape: (batch_size, 1, x,y,z)
"""
net_output = softmax_helper(net_output)
pc = net_output[:, 1, ...].type(torch.float32)
gt = target[:,0, ...].type(torch.float32)
with torch.no_grad():
pc_dist = compute_edts_forhdloss(pc.cpu().numpy()>0.5)
gt_dist = compute_edts_forhdloss(gt.cpu().numpy()>0.5)
# print('pc_dist.shape: ', pc_dist.shape)
pred_error = (gt - pc)**2
dist = pc_dist**2 + gt_dist**2 # \alpha=2 in eq(8)
dist = torch.from_numpy(dist)
if dist.device != pred_error.device:
dist = dist.to(pred_error.device).type(torch.float32)
multipled = torch.einsum("bxyz,bxyz->bxyz", pred_error, dist)
hd_loss = multipled.mean()
        return hd_loss
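The code above calls a compute_edts_forhdloss helper that the article never defines; a possible implementation based on scipy's Euclidean distance transform (an assumption about the missing helper, not the original code):

import numpy as np
from scipy.ndimage import distance_transform_edt

def compute_edts_forhdloss(segmentation):
    # per-sample distance transform of the binary mask, computed on both sides of the boundary
    res = np.zeros(segmentation.shape, dtype=np.float32)
    for b in range(segmentation.shape[0]):
        posmask = segmentation[b].astype(bool)
        negmask = ~posmask
        res[b] = distance_transform_edt(posmask) + distance_transform_edt(negmask)
    return res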
Compound loss functions
1. Exponential Logarithmic Loss
The exponential logarithmic loss focuses on less accurately predicted structures by combining the Dice loss and the cross-entropy loss. Exponential and logarithmic transforms are applied to the Dice and cross-entropy terms, combining the benefit of finer segmentation boundaries with that of accurately modeling the data distribution. It is defined as
$$L_{Exp} = w_{Dice}\;\mathbb{E}\big[(-\ln DSC_i)^{\gamma_{Dice}}\big] \;+\; w_{CE}\;\mathbb{E}\big[w_l\,(-\ln p_l)^{\gamma_{CE}}\big]$$
where $DSC_i$ is the soft Dice score for label $i$, $p_l$ is the predicted probability of the true label, and $w_l$ is a label weight.
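A minimal sketch of the exponential logarithmic loss for the binary case (the weights w_dice = 0.8, w_ce = 0.2 and gamma = 0.3 are assumed defaults; the original formulation is multi-class):

import torch
import torch.nn.functional as F

def exp_log_loss(logits, target, w_dice=0.8, w_ce=0.2, gamma=0.3, eps=1e-7):
    prob = torch.sigmoid(logits)
    dims = tuple(range(1, target.dim()))                      # sum over all non-batch dims
    intersection = (prob * target).sum(dims)
    dice = (2.0 * intersection + eps) / (prob.sum(dims) + target.sum(dims) + eps)
    l_dice = torch.pow(-torch.log(dice.clamp(min=eps)), gamma).mean()
    ce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")  # per-pixel -ln(p)
    l_ce = torch.pow(ce.clamp(min=eps), gamma).mean()
    return w_dice * l_dice + w_ce * l_ce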
2. Combo Loss
The combo loss is defined as a weighted sum of the Dice loss and a modified cross-entropy. It tries to exploit the flexibility of the Dice loss for class imbalance while using cross-entropy for smoother training. With DL denoting the Dice loss, it can be written as
$$L_{mCE} = -\frac{1}{N}\sum_{i}\big[\,\beta\, y_i \ln \hat{y}_i + (1 - \beta)(1 - y_i)\ln(1 - \hat{y}_i)\,\big], \qquad L_{Combo} = \alpha\, L_{mCE} + (1 - \alpha)\, DL$$
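A minimal sketch of the combo loss for the binary case (alpha balances Dice against the modified cross-entropy, beta weights the positive term inside the cross-entropy; the default values are illustrative assumptions):

import torch

def combo_loss(logits, target, alpha=0.5, beta=0.5, eps=1e-7):
    prob = torch.sigmoid(logits).clamp(eps, 1.0 - eps)
    # modified (class-weighted) cross-entropy
    m_ce = -(beta * target * torch.log(prob)
             + (1.0 - beta) * (1.0 - target) * torch.log(1.0 - prob)).mean()
    # soft Dice loss
    intersection = (prob * target).sum()
    dice = (2.0 * intersection + eps) / (prob.sum() + target.sum() + eps)
    return alpha * m_ce + (1.0 - alpha) * (1.0 - dice)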
Dataset: NBFS Skull Stripping Dataset
Experimental details: a simple 2D U-Net model architecture was used
Comparison experiments