This article summarizes the loss functions used for image segmentation, specifically:
Binary Cross Entropy
Weighted Cross Entropy
Balanced Cross Entropy
Dice Loss
Focal loss
Tversky loss
Focal Tversky loss
log-cosh dice loss (the new loss function proposed in the paper)
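Paper: https://arxiv.org/pdf/2006.14822.pdf
Code: https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions
Recommended project: https://github.com/JunMa11/SegLoss

Image segmentation has long been an active research area because it can help close gaps in healthcare and benefit the general public. Over the past five years, papers have proposed different objective loss functions for different situations, such as biased data and sparse segmentation. The paper summarizes most of the loss functions widely used for image segmentation and lists the situations in which they help a model converge faster and better. It also introduces a new log-cosh dice loss function and compares it with widely used loss functions on the NBFS skull-stripping dataset. Some loss functions perform well across all datasets and are a good choice when the data distribution is unknown.

Introduction

Deep learning has transformed industries from software to manufacturing. It is also widely used in medicine, for example tumor segmentation with U-Net and cancer detection with SegNet. In these applications image segmentation is essential: beyond telling us that a disease is present, the segmented image shows exactly where it is, which provides the basis for tasks such as automatically detecting lesions in CT scans.

Distribution-based loss

1. Binary Cross-Entropy

# Binary cross-entropy; the input must first pass through a sigmoid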
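import torch
import torch.nn as nn
import torch.nn.functional as F

# input and target are placeholder tensors.
# nn.BCELoss is a module, so instantiate it before calling it on (prediction, target)
nn.BCELoss()(torch.sigmoid(input), target)
# Multi-class cross-entropy; no Softmax layer is needed before this loss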
nn.CrossEntropyLoss()(input, target)

2. Weighted Binary Cross-Entropy

Weighted cross-entropy simply adds a weight β for the positive samples on top of the cross-entropy loss:

$$WCE(p, \hat{p}) = -\left(\beta\, p \log \hat{p} + (1 - p) \log(1 - \hat{p})\right)$$

Setting β > 1 reduces false negatives; setting β < 1 reduces false positives. Compared with plain cross-entropy, this gives better results when the class counts are imbalanced.

class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """
    Network has to have NO NONLINEARITY!
    """

    def __init__(self, weight=None):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.weight = weight

    def forward(self, inp, target):
        target = target.long()
        num_classes = inp.size()[1]

        # move the class axis to the end: (N, C, d1, ...) -> (N, d1, ..., C)
        i0 = 1
        i1 = 2
        while i1 < len(inp.shape):  # this is ugly but torch only allows to transpose two axes at once
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1

        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)
        target = target.view(-1,)
        wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)
        return wce_loss(inp, target)

3. Balanced Cross-Entropy

Balanced cross-entropy is similar to weighted cross-entropy, but it also weights the negative samples:

$$BCE(p, \hat{p}) = -\left(\beta\, p \log \hat{p} + (1 - \beta)(1 - p) \log(1 - \hat{p})\right)$$

4. Focal Loss

Focal loss was proposed for object detection. Its goal is to focus on hard examples, that is, to give hard-to-classify samples a larger weight. For positive samples, those predicted with high probability (easy samples) receive a smaller loss, while those predicted with low probability (hard examples) receive a larger loss, which strengthens the focus on hard examples. The cost is extra hyperparameters, which make tuning harder.

import numpy as np  # FocalLoss uses np.ndarray in its alpha type check

class FocalLoss(nn.Module):
    """
    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    This is an implementation of Focal Loss with smooth label cross entropy supported which is proposed in
    'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
        Focal_Loss = -1 * alpha * (1 - pt)**gamma * log(pt)
    :param num_class:
    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
                    focus on hard misclassified example
    :param smooth: (float,double) smooth value when cross entropy
    :param balance_index: (int) balance class index, should be specific when alpha is float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        # print(logit.shape, target.shape)

        # build the per-class weight vector alpha with shape (num_class, 1)
        alpha = self.alpha
        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        # one-hot encode the target, optionally with label smoothing
        idx = target.cpu().long()
        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
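        return loss

5. Distance map derived loss penalty term

A distance map can be defined as the distance (Euclidean, absolute, etc.) between the ground truth and the predicted map. There are two ways to incorporate distance maps: build a network architecture with a reconstruction head alongside segmentation, or introduce the map into the loss function. Following the same idea, a distance map is derived from the GT mask and used to build a custom penalty-based loss function. With this approach it is easy to guide the network toward boundary regions that are hard to segment. The loss function is defined as:

$$L(y, p) = \frac{1}{N} \sum_{i=1}^{N} (1 + \phi_i)\, L_{CE}(y_i, p_i)$$

where φ is the distance map generated from the ground truth.

class DisPenalizedCE(torch.nn.Module):
    """
    Only for binary 3D segmentation
    Network has to have NO NONLINEARITY!
    """

    def forward(self, inp, target):
        # print(inp.shape, target.shape) # (batch, 2, xyz), (batch, 2, xyz)
        # compute distance map of ground truth
        # (compute_edts_forPenalizedLoss is a helper that is not shown in this article)
        with torch.no_grad():
            dist = compute_edts_forPenalizedLoss(target.cpu().numpy()>0.5) + 1.0

        dist = torch.from_numpy(dist)
        # move to the input's device and cast to float32 in every case
        dist = dist.to(inp.device).type(torch.float32)
        dist = dist.view(-1,)

        target = target.long()
        num_classes = inp.size()[1]

        i0 = 1
        i1 = 2
        while i1 < len(inp.shape):  # this is ugly but torch only allows to transpose two axes at once
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1

        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)
        log_sm = torch.nn.LogSoftmax(dim=1)
        inp_logs = log_sm(inp)

        target = target.view(-1,)
        # loss = nll_loss(inp_logs, target)
        loss = -inp_logs[range(target.shape[0]), target]
        # print(loss.type(), dist.type())
        weighted_loss = loss*dist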
        # average the distance-weighted per-voxel loss
        return weighted_loss.mean()

Region-based loss

1. Dice Loss

The Dice coefficient is a metric widely used in computer vision to measure the similarity between two images. In 2016 it was also adapted into a loss function, known as Dice loss.

Dice coefficient: a set-similarity measure, typically used to compute the pixel-wise similarity between two samples:

$$Dice(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}$$

The factor 2 in the numerator is there because the denominator counts the elements common to X and Y twice. Dice takes values in [0, 1]. For segmentation, X is the ground-truth segmentation and Y is the predicted segmentation. The Dice loss is:

$$DL(y, \hat{p}) = 1 - \frac{2\, y\, \hat{p} + 1}{y + \hat{p} + 1}$$

Here 1 is added to the numerator and the denominator so that the function stays defined in edge cases such as y = p̂ = 0.

Dice loss suits extremely imbalanced samples. In ordinary cases it can adversely affect backpropagation and make training unstable.

def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output:
    :param gt:
    :param axes:
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2

    # sum_tensor is a helper that is not shown in this article; it sums over the given axes
    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn
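
To make the quantities returned by get_tp_fp_fn concrete, here is a minimal pure-Python sketch that computes soft TP/FP/FN for one flattened binary mask and the smoothed Dice score built from them. The function name soft_dice and the toy values are assumptions made for illustration; they are not part of the original code.

```python
def soft_dice(probs, labels, smooth=1.0):
    """Smoothed Dice score for one class on a flattened image.

    probs:  predicted foreground probabilities in [0, 1]
    labels: ground-truth labels, 0 or 1
    (hypothetical helper, for illustration only)
    """
    tp = sum(p * y for p, y in zip(probs, labels))        # soft true positives
    fp = sum(p * (1 - y) for p, y in zip(probs, labels))  # soft false positives
    fn = sum((1 - p) * y for p, y in zip(probs, labels))  # soft false negatives
    return (2 * tp + smooth) / (2 * tp + fp + fn + smooth)


labels = [1, 1, 0, 0]
perfect = soft_dice([1.0, 1.0, 0.0, 0.0], labels)  # prediction equals ground truth
wrong = soft_dice([0.0, 0.0, 1.0, 1.0], labels)    # prediction is fully inverted
empty = soft_dice([0.0, 0.0], [0, 0])              # empty mask: smooth avoids 0/0

print(perfect, wrong, empty)
```

The empty-mask case shows why the smoothing term matters: without it both the numerator and the denominator would be zero.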
class SoftDiceLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        """
        paper: https://arxiv.org/pdf/1606.04797.pdf
        """
        super(SoftDiceLoss, self).__init__()

        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)

        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        if not self.do_bg:
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()
        return -dc

2. Tversky Loss

Paper: https://arxiv.org/pdf/1706.05721.pdf. The Tversky index generalizes the Dice coefficient and the Jaccard index:

$$TI = \frac{TP}{TP + \alpha\, FP + \beta\, FN}$$

With α = β = 0.5 the Tversky index is the Dice coefficient, and with α = β = 1 it is the Jaccard index. In the implementation below, α weights false positives and β weights false negatives, so adjusting them controls the trade-off between the two.

class TverskyLoss(nn.Module):
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        """
        paper: https://arxiv.org/pdf/1706.05721.pdf
        """
        super(TverskyLoss, self).__init__()

        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.alpha = 0.3
        self.beta = 0.7

    def forward(self, x, y, loss_mask=None):
        shp_x = x.shape

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)

        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)

        tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)

        if not self.do_bg:
            if self.batch_dice:
                tversky = tversky[1:]
            else:
                tversky = tversky[:, 1:]
        tversky = tversky.mean()

        return -tversky
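
The two special cases of the Tversky index can be checked numerically. The sketch below is plain Python on toy counts; the function name tversky_index and the counts are assumptions for illustration, and smoothing is left out so the identities hold exactly.

```python
def tversky_index(tp, fp, fn, alpha, beta):
    """Tversky index TP / (TP + alpha*FP + beta*FN), without smoothing."""
    return tp / (tp + alpha * fp + beta * fn)


tp, fp, fn = 6.0, 2.0, 4.0  # toy counts, chosen arbitrarily

dice = 2 * tp / (2 * tp + fp + fn)
jaccard = tp / (tp + fp + fn)

as_dice = tversky_index(tp, fp, fn, alpha=0.5, beta=0.5)
as_jaccard = tversky_index(tp, fp, fn, alpha=1.0, beta=1.0)

# With alpha=0.3, beta=0.7 (the values hard-coded in TverskyLoss above),
# the same number of errors costs more when they are false negatives.
fn_heavy = tversky_index(tp, fp=2.0, fn=4.0, alpha=0.3, beta=0.7)
fp_heavy = tversky_index(tp, fp=4.0, fn=2.0, alpha=0.3, beta=0.7)

print(dice, as_dice)
print(jaccard, as_jaccard)
print(fn_heavy, fp_heavy)
```

Because the index drops faster for false negatives when β > α, minimizing the negated index pushes the model toward higher recall.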
3. Focal Tversky Loss

Like focal loss, which focuses on hard examples by down-weighting easy or common ones, Focal Tversky loss uses a γ coefficient to learn hard examples, such as cases with a small ROI (region of interest):

$$FTL = \sum_c (1 - TI_c)^{\gamma}$$

class FocalTversky_loss(nn.Module):
    """
    paper: https://arxiv.org/pdf/1810.07842.pdf
    author code: https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65
    """

    def __init__(self, tversky_kwargs, gamma=0.75):
        super(FocalTversky_loss, self).__init__()
        self.gamma = gamma
        self.tversky = TverskyLoss(**tversky_kwargs)

    def forward(self, net_output, target):
        # TverskyLoss returns -TI, so this is 1 - TI
        tversky_loss = 1 + self.tversky(net_output, target)
        focal_tversky = torch.pow(tversky_loss, self.gamma)
        return focal_tversky
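
To see what the exponent does, the following plain-Python sketch compares 1 − TI with (1 − TI)^γ for a few Tversky index values, using the default γ = 0.75 from the class above. The function name focal_tversky and the sample values are assumptions for illustration only.

```python
def focal_tversky(ti, gamma=0.75):
    """(1 - TI) ** gamma for a Tversky index ti in [0, 1]."""
    return (1.0 - ti) ** gamma


for ti in (0.5, 0.9, 0.99):
    plain = 1.0 - ti
    focal = focal_tversky(ti)
    # ratio = (1 - TI) ** (gamma - 1): how much the focal form scales the plain loss
    print(f"TI={ti:.2f}  1-TI={plain:.3f}  focal={focal:.3f}  ratio={focal / plain:.2f}")
```

With γ < 1 the ratio grows as TI approaches 1, so the loss does not vanish as quickly as 1 − TI; with γ = 1 the two coincide.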
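4. Sensitivity Specificity Loss

The Sensitivity Specificity loss is:

$$SS = \lambda\, \frac{\sum_n (y_n - p_n)^2\, y_n}{\sum_n y_n + \epsilon} + (1 - \lambda)\, \frac{\sum_n (y_n - p_n)^2\, (1 - y_n)}{\sum_n (1 - y_n) + \epsilon}$$

where y_n is the ground-truth label and p_n is the prediction. The left term is the error rate over lesion pixels, that is, 1 − Sensitivity rather than the accuracy, so λ is set to 0.05. The squared error (y_n − p_n)^2 is used to obtain smooth gradients. The implementation below has the same form, with the weight stored as r = 0.1.

class SSLoss(nn.Module):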
    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
                 square=False):
        """
        Sensitivity-Specificity loss
        paper: http://www.rogertam.ca/Brosch_MICCAI_2015.pdf
        tf code: https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392
        """
        super(SSLoss, self).__init__()

        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.r = 0.1 # weight parameter in SS paper

    def forward(self, net_output, gt, loss_mask=None):
        shp_x = net_output.shape
        shp_y = gt.shape
        # class_num = shp_x[1]

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                gt = gt.view((shp_y[0], 1, *shp_y[1:]))

            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(shp_x)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)

        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        # use the raw output when no nonlinearity is configured
        softmax_output = net_output
        if self.apply_nonlin is not None:
            softmax_output = self.apply_nonlin(net_output)

        # no object value
        bg_onehot = 1 - y_onehot
        squared_error = (y_onehot - softmax_output)**2
        # the first term averages the error over foreground voxels, the second over background voxels
        specificity_part = sum_tensor(squared_error*y_onehot, axes)/(sum_tensor(y_onehot, axes)+self.smooth)
        sensitivity_part = sum_tensor(squared_error*bg_onehot, axes)/(sum_tensor(bg_onehot, axes)+self.smooth)

        ss = self.r * specificity_part + (1-self.r) * sensitivity_part

        if not self.do_bg:
            if self.batch_dice:
                ss = ss[1:]
            else:
                ss = ss[:, 1:]
        ss = ss.mean()
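        return ss

5. Log-Cosh Dice Loss (the loss function proposed in the paper)

The Dice coefficient is a metric for evaluating segmentation output. It has also been adapted into a loss function because it gives a mathematical representation of the segmentation objective. However, because it is non-convex, it often fails to reach optimal results. Lovász-softmax loss aims to address the non-convexity by adding smoothing based on the Lovász extension. Meanwhile, the log-cosh approach is widely used in regression problems to smooth the curve. Combining the cosh(x) and log(x) functions gives the Log-Cosh Dice loss:

$$L_{lc\text{-}dce} = \log(\cosh(\text{DiceLoss}))$$

def log_cosh_dice_loss(self, y_true, y_pred):
    # self.dice_loss is assumed to return the Dice loss as a torch tensor
    x = self.dice_loss(y_true, y_pred)
    # log(cosh(x)), written with torch operations only
    return torch.log((torch.exp(x) + torch.exp(-x)) / 2.0)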
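
The smoothing effect of log-cosh can be seen on scalars. The sketch below is plain Python and uses an algebraically equivalent, overflow-safe form of log(cosh(x)); the function names log_cosh and log_cosh_dice are assumptions for illustration, not names from the paper's code.

```python
import math


def log_cosh(x):
    """Stable log(cosh(x)), using cosh(x) = exp(|x|) * (1 + exp(-2|x|)) / 2."""
    ax = abs(x)
    return ax + math.log1p(math.exp(-2.0 * ax)) - math.log(2.0)


def log_cosh_dice(dice_score):
    """Log-cosh applied to the Dice loss 1 - dice_score."""
    return log_cosh(1.0 - dice_score)


for d in (1.0, 0.9, 0.5, 0.0):
    x = 1.0 - d  # Dice loss
    # near x = 0, log(cosh(x)) behaves like x**2 / 2; its derivative is tanh(x)
    print(d, log_cosh_dice(d), x * x / 2.0, math.tanh(x))
```

The derivative tanh(x) is bounded in (−1, 1) and close to x near zero, which is the sense in which log-cosh smooths the underlying loss.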
Boundary-based loss

1. Shape-aware Loss

As the name suggests, shape-aware loss takes shape into account. Loss functions generally work at the pixel level. Shape-aware loss instead computes the average point-to-curve Euclidean distance, that is, the Euclidean distance between points around the curve of the predicted segmentation and the ground truth, and uses it as a coefficient on the cross-entropy (CE) loss. The implementation shown here is a distance-map-penalized Dice loss.

class DistBinaryDiceLoss(nn.Module):
    """
    Distance map penalized Dice loss
    Motivated by: https://openreview.net/forum?id=B1eIcvS45V
    Distance Map Loss Penalty Term for Semantic Segmentation
    """

    def __init__(self, smooth=1e-5):
        super(DistBinaryDiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, net_output, gt):
        """
        net_output: (batch_size, 2, x,y,z)
        target: ground truth, shape: (batch_size, 1, x,y,z)
        """
        # softmax_helper and compute_edts_forPenalizedLoss are helpers not shown in this article
        net_output = softmax_helper(net_output)
        # one hot code for gt
        with torch.no_grad():
            if len(net_output.shape) != len(gt.shape):
                gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))

            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(net_output.shape)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)

        gt_temp = gt[:,0, ...].type(torch.float32)
        with torch.no_grad():
            dist = compute_edts_forPenalizedLoss(gt_temp.cpu().numpy()>0.5) + 1.0
        # print('dist.shape: ', dist.shape)
        dist = torch.from_numpy(dist)
        # move to the output's device and cast to float32 in every case
        dist = dist.to(net_output.device).type(torch.float32)

        tp = net_output * y_onehot
        tp = torch.sum(tp[:,1,...] * dist, (1,2,3))

        dc = (2 * tp + self.smooth) / (torch.sum(net_output[:,1,...], (1,2,3)) + torch.sum(y_onehot[:,1,...], (1,2,3)) + self.smooth)

        dc = dc.mean()

        return -dc
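
The distance-map helper used above is not shown in this article. As a toy illustration of the idea only, the sketch below builds a 1-D weight map that is largest next to a label change and uses it to weight a per-pixel cross-entropy. The names boundary_weight_1d and penalized_ce, and the weighting scheme itself, are assumptions made for this example; they do not reproduce the repository's helper.

```python
import math


def boundary_weight_1d(mask):
    """Toy weight map in [1, 2]: 2 at a label change, falling to 1 far from it."""
    n = len(mask)
    # an index is on the boundary if its label differs from a neighbour's
    boundary = [i for i in range(n)
                if (i > 0 and mask[i] != mask[i - 1])
                or (i < n - 1 and mask[i] != mask[i + 1])]
    if not boundary:
        return [1.0] * n  # no boundary: uniform weights
    dist = [min(abs(i - b) for b in boundary) for i in range(n)]
    d_max = max(dist)
    if d_max == 0:
        return [2.0] * n  # every index touches a boundary
    return [1.0 + (d_max - d) / d_max for d in dist]


def penalized_ce(probs, mask):
    """Mean of per-pixel binary cross-entropy scaled by the toy weight map."""
    weights = boundary_weight_1d(mask)
    ce = [-math.log(p if y == 1 else 1.0 - p) for p, y in zip(probs, mask)]
    return sum(w * c for w, c in zip(weights, ce)) / len(mask)


mask = [0, 0, 0, 0, 1, 1, 0, 0]
weights = boundary_weight_1d(mask)
loss = penalized_ce([0.5] * len(mask), mask)  # uninformative prediction
print(weights)
print(loss)
```

Errors made next to the object boundary contribute up to twice as much as errors far from it, which is the kind of emphasis a distance-map penalty is meant to provide.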
2. Hausdorff Distance Loss

Hausdorff Distance (HD) is a metric that segmentation methods use to track model performance. It is defined as:

$$d(X, Y) = \max_{x \in X} \min_{y \in Y} \lVert x - y \rVert_2$$

The goal of any segmentation model is to minimize the Hausdorff distance, but because it is non-convex it is not widely used as a loss function. Researchers have proposed three variants of Hausdorff-distance-based loss functions; all of them incorporate the metric and keep the loss tractable.

class HDDTBinaryLoss(nn.Module):
    def __init__(self):
        """
        compute Hausdorff loss for binary segmentation
        https://arxiv.org/pdf/1904.10030v1.pdf
        """
        super(HDDTBinaryLoss, self).__init__()

    def forward(self, net_output, target):
        """
        net_output: (batch_size, 2, x,y,z)
        target: ground truth, shape: (batch_size, 1, x,y,z)
        """
        # softmax_helper and compute_edts_forhdloss are helpers not shown in this article
        net_output = softmax_helper(net_output)
        pc = net_output[:, 1, ...].type(torch.float32)
        gt = target[:,0, ...].type(torch.float32)
        with torch.no_grad():
            # detach so that .numpy() also works when pc lives on the CPU and requires grad
            pc_dist = compute_edts_forhdloss(pc.detach().cpu().numpy()>0.5)
            gt_dist = compute_edts_forhdloss(gt.detach().cpu().numpy()>0.5)
        # print('pc_dist.shape: ', pc_dist.shape)

        pred_error = (gt - pc)**2
        dist = pc_dist**2 + gt_dist**2 # \alpha=2 in eq(8)

        dist = torch.from_numpy(dist)
        # move to the prediction's device and cast to float32 in every case
        dist = dist.to(pred_error.device).type(torch.float32)

        multipled = torch.einsum("bxyz,bxyz->bxyz", pred_error, dist)
        hd_loss = multipled.mean()
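        return hd_loss

Compounded loss

1. Exponential Logarithmic Loss

Exponential logarithmic loss uses a combined formulation of Dice loss and cross-entropy loss to focus on structures that are predicted less accurately. Exponential and logarithmic transforms are applied to both the Dice loss and the entropy loss, to gain the benefits of finer segmentation boundaries and an accurate data distribution. It is defined as:

$$L_{Exp} = w_{Dice}\, L_{Dice} + w_{cross}\, L_{cross}$$

$$L_{Dice} = E\left[(-\ln(DC))^{\gamma_{Dice}}\right], \qquad L_{cross} = E\left[w_l\, (-\ln(p_l(x)))^{\gamma_{cross}}\right]$$

2. Combo Loss

Combo loss is defined as a weighted sum of Dice loss (DL) and a modified cross-entropy. It tries to use the flexibility of Dice loss for class imbalance while using cross-entropy for curve smoothing.

Experiments and Results

Dataset: NBFS Skull Stripping Dataset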