這是一篇關(guān)于圖像分割損失函數(shù)的總結(jié),具體包括:
Binary Cross Entropy
Weighted Cross Entropy
Balanced Cross Entropy
Dice Loss
Focal loss
Tversky loss
Focal Tversky loss
log-cosh dice loss (本文提出的新?lián)p失函數(shù))
代碼地址:https://github.com/shruti-jadon/Semantic-Segmentation-Loss-Functions
項(xiàng)目推薦:https://github.com/JunMa11/SegLoss
圖像分割一直是一個活躍的研究領(lǐng)域,因?yàn)樗锌赡苄迯?fù)醫(yī)療領(lǐng)域的漏洞,并幫助大眾。在過去的5年里,各種論文提出了不同的目標(biāo)損失函數(shù),用于不同的情況下,如偏差數(shù)據(jù),稀疏分割等。在本文中,總結(jié)了大多數(shù)廣泛用于圖像分割的損失函數(shù),并列出了它們可以幫助模型更快速、更好的收斂模型的情況。此外,本文還介紹了一種新的log-cosh dice損失函數(shù),并將其在NBFS skull-stripping數(shù)據(jù)集上與廣泛使用的損失函數(shù)進(jìn)行了性能比較。某些損失函數(shù)在所有數(shù)據(jù)集上都表現(xiàn)良好,在未知分布數(shù)據(jù)集上可以作為一個很好的選擇。
簡介
深度學(xué)習(xí)徹底改變了從軟件到制造業(yè)的各個行業(yè)。深度學(xué)習(xí)在醫(yī)學(xué)界的應(yīng)用也十分廣泛,例如使用U-Net進(jìn)行腫瘤分割、使用SegNet進(jìn)行癌癥檢測等。在這些應(yīng)用中,圖像分割是至關(guān)重要的,分割后的圖像除了告訴我們存在某種疾病外,還展示了它到底存在于何處,這為實(shí)現(xiàn)自動檢測CT掃描中的病變等功能提供基礎(chǔ)保障。
圖像分割可以定義為像素級別的分類任務(wù)。圖像由各種像素組成,這些像素組合在一起定義了圖像中的不同元素,因此將這些像素分類為一類元素的方法稱為語義圖像分割。在設(shè)計基于復(fù)雜圖像分割的深度學(xué)習(xí)架構(gòu)時,通常會遇到了一個至關(guān)重要的選擇,即選擇哪個損失/目標(biāo)函數(shù),因?yàn)樗鼈儠ぐl(fā)算法的學(xué)習(xí)過程。損失函數(shù)的選擇對于任何架構(gòu)學(xué)習(xí)正確的目標(biāo)都是至關(guān)重要的,因此自2012年以來,各種研究人員開始設(shè)計針對特定領(lǐng)域的損失函數(shù),以為其數(shù)據(jù)集獲得更好的結(jié)果。
在本文中,總結(jié)了15種基于圖像分割的損失函數(shù)。被證明可以在不同領(lǐng)域提供最新技術(shù)成果。這些損失函數(shù)可大致分為4類:基于分布的損失函數(shù),基于區(qū)域的損失函數(shù),基于邊界的損失函數(shù)和基于復(fù)合的損失函數(shù)(Distribution-based,Region-based, Boundary-based, and Compounded)。
本文還討論了確定哪種目標(biāo)/損失函數(shù)在場景中可能有用的條件。除此之外,還提出了一種新的log-cosh dice損失函數(shù)用于圖像語義分割。為了展示其效率,還比較了NBFS頭骨剝離數(shù)據(jù)集上所有損失函數(shù)的性能。
Distribution-based loss
1. Binary Cross-Entropy:二進(jìn)制交叉熵?fù)p失函數(shù)
交叉熵定義為對給定隨機(jī)變量或事件集的兩個概率分布之間的差異的度量。它被廣泛用于分類任務(wù),并且由于分割是像素級分類,因此效果很好。在多分類任務(wù)中,經(jīng)常采用 softmax 激活函數(shù)+交叉熵?fù)p失函數(shù),因?yàn)榻徊骒孛枋隽藘蓚€概率分布的差異,然而神經(jīng)網(wǎng)絡(luò)輸出的是向量,并不是概率分布的形式。所以需要 softmax激活函數(shù)將一個向量進(jìn)行“歸一化”成概率分布的形式,再采用交叉熵?fù)p失函數(shù)計算 loss。
交叉熵?fù)p失函數(shù)可以用在大多數(shù)語義分割場景中,但它有一個明顯的缺點(diǎn):當(dāng)圖像分割任務(wù)只需要分割前景和背景兩種情況。當(dāng)前景像素的數(shù)量遠(yuǎn)遠(yuǎn)小于背景像素的數(shù)量時,即的數(shù)量遠(yuǎn)大于的數(shù)量,損失函數(shù)中的成分就會占據(jù)主導(dǎo),使得模型嚴(yán)重偏向背景,導(dǎo)致效果不好。
#二值交叉熵,這里輸入要經(jīng)過sigmoid處理 importtorch importtorch.nnasnn importtorch.nn.functionalasF nn.BCELoss(F.sigmoid(input),target) #多分類交叉熵,用這個loss前面不需要加Softmax層 nn.CrossEntropyLoss(input,target)
2、Weighted Binary Cross-Entropy加權(quán)交叉熵?fù)p失函數(shù)
classWeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss): """ NetworkhastohaveNONONLINEARITY! """ def__init__(self,weight=None): super(WeightedCrossEntropyLoss,self).__init__() self.weight=weight defforward(self,inp,target): target=target.long() num_classes=inp.size()[1] i0=1 i1=2 whilei1
3、Balanced Cross-Entropy平衡交叉熵?fù)p失函數(shù)
與加權(quán)交叉熵?fù)p失函數(shù)類似,但平衡交叉熵?fù)p失函數(shù)對負(fù)樣本也進(jìn)行加權(quán)。
4、Focal Loss
Focal loss是在目標(biāo)檢測領(lǐng)域提出來的。其目的是關(guān)注難例(也就是給難分類的樣本較大的權(quán)重)。對于正樣本,使預(yù)測概率大的樣本(簡單樣本)得到的loss變小,而預(yù)測概率小的樣本(難例)loss變得大,從而加強(qiáng)對難例的關(guān)注度。但引入了額外參數(shù),增加了調(diào)參難度。
classFocalLoss(nn.Module): """ copyfrom:https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py ThisisaimplementationofFocalLosswithsmoothlabelcrossentropysupportedwhichisproposedin 'FocalLossforDenseObjectDetection.(https://arxiv.org/abs/1708.02002)' Focal_Loss=-1*alpha*(1-pt)*log(pt) :paramnum_class: :paramalpha:(tensor)3Dor4Dthescalarfactorforthiscriterion :paramgamma:(float,double)gamma>0reducestherelativelossforwell-classifiedexamples(p>0.5)puttingmore focusonhardmisclassifiedexample :paramsmooth:(float,double)smoothvaluewhencrossentropy :parambalance_index:(int)balanceclassindex,shouldbespecificwhenalphaisfloat :paramsize_average:(bool,optional)Bydefault,thelossesareaveragedovereachlosselementinthebatch. """ def__init__(self,apply_nonlin=None,alpha=None,gamma=2,balance_index=0,smooth=1e-5,size_average=True): super(FocalLoss,self).__init__() self.apply_nonlin=apply_nonlin self.alpha=alpha self.gamma=gamma self.balance_index=balance_index self.smooth=smooth self.size_average=size_average ifself.smoothisnotNone: ifself.smooth0?or?self.smooth?>1.0: raiseValueError('smoothvalueshouldbein[0,1]') defforward(self,logit,target): ifself.apply_nonlinisnotNone: logit=self.apply_nonlin(logit) num_class=logit.shape[1] iflogit.dim()>2: #N,C,d1,d2->N,C,m(m=d1*d2*...) logit=logit.view(logit.size(0),logit.size(1),-1) logit=logit.permute(0,2,1).contiguous() logit=logit.view(-1,logit.size(-1)) target=torch.squeeze(target,1) target=target.view(-1,1) #print(logit.shape,target.shape) # alpha=self.alpha ifalphaisNone: alpha=torch.ones(num_class,1) elifisinstance(alpha,(list,np.ndarray)): assertlen(alpha)==num_class alpha=torch.FloatTensor(alpha).view(num_class,1) alpha=alpha/alpha.sum() elifisinstance(alpha,float): alpha=torch.ones(num_class,1) alpha=alpha*(1-self.alpha) alpha[self.balance_index]=self.alpha else: raiseTypeError('Notsupportalphatype') ifalpha.device!=logit.device: alpha=alpha.to(logit.device) idx=target.cpu().long() one_hot_key=torch.FloatTensor(target.size(0),num_class).zero_() one_hot_key=one_hot_key.scatter_(1,idx,1) ifone_hot_key.device!=logit.device: one_hot_key=one_hot_key.to(logit.device) ifself.smooth: one_hot_key=torch.clamp( one_hot_key,self.smooth/(num_class-1),1.0-self.smooth) pt=(one_hot_key*logit).sum(1)+self.smooth logpt=pt.log() gamma=self.gamma alpha=alpha[idx] alpha=torch.squeeze(alpha) loss=-1*alpha*torch.pow((1-pt),gamma)*logpt ifself.size_average: loss=loss.mean() else: loss=loss.sum() returnloss
5、Distance map derived loss penalty term距離圖得出的損失懲罰項(xiàng)
可以將距離圖定義為ground truth與預(yù)測圖之間的距離(歐幾里得距離、絕對距離等)。合并映射的方法有2種,一種是創(chuàng)建神經(jīng)網(wǎng)絡(luò)架構(gòu),在該算法中有一個用于分割的重建head,或者將其引入損失函數(shù)。遵循相同的理論,可以從GT mask得出的距離圖,并創(chuàng)建了一個基于懲罰的自定義損失函數(shù)。使用這種方法,可以很容易地將網(wǎng)絡(luò)引導(dǎo)到難以分割的邊界區(qū)域。損失函數(shù)定義為:
classDisPenalizedCE(torch.nn.Module): """ Onlyforbinary3Dsegmentation NetworkhastohaveNONONLINEARITY! """ defforward(self,inp,target): #print(inp.shape,target.shape)#(batch,2,xyz),(batch,2,xyz) #computedistancemapofgroundtruth withtorch.no_grad(): dist=compute_edts_forPenalizedLoss(target.cpu().numpy()>0.5)+1.0 dist=torch.from_numpy(dist) ifdist.device!=inp.device: dist=dist.to(inp.device).type(torch.float32) dist=dist.view(-1,) target=target.long() num_classes=inp.size()[1] i0=1 i1=2 whilei1
Region-based loss
1、Dice Loss
Dice系數(shù)是計算機(jī)視覺界廣泛使用的度量標(biāo)準(zhǔn),用于計算兩個圖像之間的相似度。在2016年的時候,它也被改編為損失函數(shù),稱為Dice損失。
defget_tp_fp_fn(net_output,gt,axes=None,mask=None,square=False): """ net_outputmustbe(b,c,x,y(,z))) gtmustbealabelmap(shape(b,1,x,y(,z))ORshape(b,x,y(,z)))oronehotencoding(b,c,x,y(,z)) ifmaskisprovideditmusthaveshape(b,1,x,y(,z))) :paramnet_output: :paramgt: :paramaxes: :parammask:maskmustbe1forvalidpixelsand0forinvalidpixels :paramsquare:ifTruethenfp,tpandfnwillbesquaredbeforesummation :return: """ ifaxesisNone: axes=tuple(range(2,len(net_output.size()))) shp_x=net_output.shape shp_y=gt.shape withtorch.no_grad(): iflen(shp_x)!=len(shp_y): gt=gt.view((shp_y[0],1,*shp_y[1:])) ifall([i==jfori,jinzip(net_output.shape,gt.shape)]): #ifthisisthecasethengtisprobablyalreadyaonehotencoding y_onehot=gt else: gt=gt.long() y_onehot=torch.zeros(shp_x) ifnet_output.device.type=="cuda": y_onehot=y_onehot.cuda(net_output.device.index) y_onehot.scatter_(1,gt,1) tp=net_output*y_onehot fp=net_output*(1-y_onehot) fn=(1-net_output)*y_onehot ifmaskisnotNone: tp=torch.stack(tuple(x_i*mask[:,0]forx_iintorch.unbind(tp,dim=1)),dim=1) fp=torch.stack(tuple(x_i*mask[:,0]forx_iintorch.unbind(fp,dim=1)),dim=1) fn=torch.stack(tuple(x_i*mask[:,0]forx_iintorch.unbind(fn,dim=1)),dim=1) ifsquare: tp=tp**2 fp=fp**2 fn=fn**2 tp=sum_tensor(tp,axes,keepdim=False) fp=sum_tensor(fp,axes,keepdim=False) fn=sum_tensor(fn,axes,keepdim=False) returntp,fp,fn classSoftDiceLoss(nn.Module): def__init__(self,apply_nonlin=None,batch_dice=False,do_bg=True,smooth=1., square=False): """ paper:https://arxiv.org/pdf/1606.04797.pdf """ super(SoftDiceLoss,self).__init__() self.square=square self.do_bg=do_bg self.batch_dice=batch_dice self.apply_nonlin=apply_nonlin self.smooth=smooth defforward(self,x,y,loss_mask=None): shp_x=x.shape ifself.batch_dice: axes=[0]+list(range(2,len(shp_x))) else: axes=list(range(2,len(shp_x))) ifself.apply_nonlinisnotNone: x=self.apply_nonlin(x) tp,fp,fn=get_tp_fp_fn(x,y,axes,loss_mask,self.square) dc=(2*tp+self.smooth)/(2*tp+fp+fn+self.smooth) ifnotself.do_bg: ifself.batch_dice: dc=dc[1:] else: dc=dc[:,1:] dc=dc.mean() return-dc
2、Tversky Loss
Tversky系數(shù)是Dice系數(shù)和 Jaccard 系數(shù)的一種推廣。當(dāng)設(shè)置α=β=0.5,此時Tversky系數(shù)就是Dice系數(shù)。而當(dāng)設(shè)置α=β=1時,此時Tversky系數(shù)就是Jaccard系數(shù)。α和β分別控制假陰性和假陽性。通過調(diào)整α和β,可以控制假陽性和假陰性之間的平衡。
classTverskyLoss(nn.Module): def__init__(self,apply_nonlin=None,batch_dice=False,do_bg=True,smooth=1., square=False): """ paper:https://arxiv.org/pdf/1706.05721.pdf """ super(TverskyLoss,self).__init__() self.square=square self.do_bg=do_bg self.batch_dice=batch_dice self.apply_nonlin=apply_nonlin self.smooth=smooth self.alpha=0.3 self.beta=0.7 defforward(self,x,y,loss_mask=None): shp_x=x.shape ifself.batch_dice: axes=[0]+list(range(2,len(shp_x))) else: axes=list(range(2,len(shp_x))) ifself.apply_nonlinisnotNone: x=self.apply_nonlin(x) tp,fp,fn=get_tp_fp_fn(x,y,axes,loss_mask,self.square) tversky=(tp+self.smooth)/(tp+self.alpha*fp+self.beta*fn+self.smooth) ifnotself.do_bg: ifself.batch_dice: tversky=tversky[1:] else: tversky=tversky[:,1:] tversky=tversky.mean() return-tversky
3、Focal Tversky Loss
與“Focal loss”相似,后者著重于通過降低易用/常見損失的權(quán)重來說明困難的例子。Focal Tversky Loss還嘗試借助γ系數(shù)來學(xué)習(xí)諸如在ROI(感興趣區(qū)域)較小的情況下的困難示例,如下所示:
classFocalTversky_loss(nn.Module): """ paper:https://arxiv.org/pdf/1810.07842.pdf authorcode:https://github.com/nabsabraham/focal-tversky-unet/blob/347d39117c24540400dfe80d106d2fb06d2b99e1/losses.py#L65 """ def__init__(self,tversky_kwargs,gamma=0.75): super(FocalTversky_loss,self).__init__() self.gamma=gamma self.tversky=TverskyLoss(**tversky_kwargs) defforward(self,net_output,target): tversky_loss=1+self.tversky(net_output,target)#=1-tversky(net_output,target) focal_tversky=torch.pow(tversky_loss,self.gamma) returnfocal_tversky
4、Sensitivity Specificity Loss
首先敏感性就是召回率,檢測出確實(shí)有病的能力:
特異性,檢測出確實(shí)沒病的能力:
而Sensitivity Specificity Loss為:
classSSLoss(nn.Module): def__init__(self,apply_nonlin=None,batch_dice=False,do_bg=True,smooth=1., square=False): """ Sensitivity-Specifityloss paper:http://www.rogertam.ca/Brosch_MICCAI_2015.pdf tfcode:https://github.com/NifTK/NiftyNet/blob/df0f86733357fdc92bbc191c8fec0dcf49aa5499/niftynet/layer/loss_segmentation.py#L392 """ super(SSLoss,self).__init__() self.square=square self.do_bg=do_bg self.batch_dice=batch_dice self.apply_nonlin=apply_nonlin self.smooth=smooth self.r=0.1#weightparameterinSSpaper defforward(self,net_output,gt,loss_mask=None): shp_x=net_output.shape shp_y=gt.shape #class_num=shp_x[1] withtorch.no_grad(): iflen(shp_x)!=len(shp_y): gt=gt.view((shp_y[0],1,*shp_y[1:])) ifall([i==jfori,jinzip(net_output.shape,gt.shape)]): #ifthisisthecasethengtisprobablyalreadyaonehotencoding y_onehot=gt else: gt=gt.long() y_onehot=torch.zeros(shp_x) ifnet_output.device.type=="cuda": y_onehot=y_onehot.cuda(net_output.device.index) y_onehot.scatter_(1,gt,1) ifself.batch_dice: axes=[0]+list(range(2,len(shp_x))) else: axes=list(range(2,len(shp_x))) ifself.apply_nonlinisnotNone: softmax_output=self.apply_nonlin(net_output) #noobjectvalue bg_onehot=1-y_onehot squared_error=(y_onehot-softmax_output)**2 specificity_part=sum_tensor(squared_error*y_onehot,axes)/(sum_tensor(y_onehot,axes)+self.smooth) sensitivity_part=sum_tensor(squared_error*bg_onehot,axes)/(sum_tensor(bg_onehot,axes)+self.smooth) ss=self.r*specificity_part+(1-self.r)*sensitivity_part ifnotself.do_bg: ifself.batch_dice: ss=ss[1:] else: ss=ss[:,1:] ss=ss.mean() returnss
5、Log-Cosh Dice Loss(本文提出的損失函數(shù))
Dice系數(shù)是一種用于評估分割輸出的度量標(biāo)準(zhǔn)。它也已修改為損失函數(shù),因?yàn)樗梢詫?shí)現(xiàn)分割目標(biāo)的數(shù)學(xué)表示。但是由于其非凸性,它多次都無法獲得最佳結(jié)果。Lovsz-softmax損失旨在通過添加使用Lovsz擴(kuò)展的平滑來解決非凸損失函數(shù)的問題。同時,Log-Cosh方法已廣泛用于基于回歸的問題中,以平滑曲線。
將Cosh(x)函數(shù)和Log(x)函數(shù)合并,可以得到Log-Cosh Dice Loss:
deflog_cosh_dice_loss(self,y_true,y_pred): x=self.dice_loss(y_true,y_pred) returntf.math.log((torch.exp(x)+torch.exp(-x))/2.0)
Boundary-based loss
1、Shape-aware Loss
顧名思義,Shape-aware Loss考慮了形狀。通常,所有損失函數(shù)都在像素級起作用,Shape-aware Loss會計算平均點(diǎn)到曲線的歐幾里得距離,即預(yù)測分割到ground truth的曲線周圍點(diǎn)之間的歐式距離,并將其用作交叉熵?fù)p失函數(shù)的系數(shù),具體定義如下:(CE指交叉熵?fù)p失函數(shù))
classDistBinaryDiceLoss(nn.Module): """ DistancemappenalizedDiceloss Motivatedby:https://openreview.net/forum?id=B1eIcvS45V DistanceMapLossPenaltyTermforSemanticSegmentation """ def__init__(self,smooth=1e-5): super(DistBinaryDiceLoss,self).__init__() self.smooth=smooth defforward(self,net_output,gt): """ net_output:(batch_size,2,x,y,z) target:groundtruth,shape:(batch_size,1,x,y,z) """ net_output=softmax_helper(net_output) #onehotcodeforgt withtorch.no_grad(): iflen(net_output.shape)!=len(gt.shape): gt=gt.view((gt.shape[0],1,*gt.shape[1:])) ifall([i==jfori,jinzip(net_output.shape,gt.shape)]): #ifthisisthecasethengtisprobablyalreadyaonehotencoding y_onehot=gt else: gt=gt.long() y_onehot=torch.zeros(net_output.shape) ifnet_output.device.type=="cuda": y_onehot=y_onehot.cuda(net_output.device.index) y_onehot.scatter_(1,gt,1) gt_temp=gt[:,0,...].type(torch.float32) withtorch.no_grad(): dist=compute_edts_forPenalizedLoss(gt_temp.cpu().numpy()>0.5)+1.0 #print('dist.shape:',dist.shape) dist=torch.from_numpy(dist) ifdist.device!=net_output.device: dist=dist.to(net_output.device).type(torch.float32) tp=net_output*y_onehot tp=torch.sum(tp[:,1,...]*dist,(1,2,3)) dc=(2*tp+self.smooth)/(torch.sum(net_output[:,1,...],(1,2,3))+torch.sum(y_onehot[:,1,...],(1,2,3))+self.smooth) dc=dc.mean() return-dc
2、Hausdorff Distance Loss
Hausdorff Distance Loss(HD)是分割方法用來跟蹤模型性能的度量。
任何分割模型的目的都是為了最大化Hausdorff距離,但是由于其非凸性,因此并未廣泛用作損失函數(shù)。有研究者提出了基于Hausdorff距離的損失函數(shù)的3個變量,它們都結(jié)合了度量用例,并確保損失函數(shù)易于處理。
classHDDTBinaryLoss(nn.Module): def__init__(self): """ computehaudorfflossforbinarysegmentation https://arxiv.org/pdf/1904.10030v1.pdf """ super(HDDTBinaryLoss,self).__init__() defforward(self,net_output,target): """ net_output:(batch_size,2,x,y,z) target:groundtruth,shape:(batch_size,1,x,y,z) """ net_output=softmax_helper(net_output) pc=net_output[:,1,...].type(torch.float32) gt=target[:,0,...].type(torch.float32) withtorch.no_grad(): pc_dist=compute_edts_forhdloss(pc.cpu().numpy()>0.5) gt_dist=compute_edts_forhdloss(gt.cpu().numpy()>0.5) #print('pc_dist.shape:',pc_dist.shape) pred_error=(gt-pc)**2 dist=pc_dist**2+gt_dist**2#/alpha=2ineq(8) dist=torch.from_numpy(dist) ifdist.device!=pred_error.device: dist=dist.to(pred_error.device).type(torch.float32) multipled=torch.einsum("bxyz,bxyz->bxyz",pred_error,dist) hd_loss=multipled.mean() returnhd_loss
Compounded loss
1、Exponential Logarithmic Loss
指數(shù)對數(shù)損失函數(shù)集中于使用骰子損失和交叉熵?fù)p失的組合公式來預(yù)測不那么精確的結(jié)構(gòu)。對骰子損失和熵?fù)p失進(jìn)行指數(shù)和對數(shù)轉(zhuǎn)換,以合并更精細(xì)的分割邊界和準(zhǔn)確的數(shù)據(jù)分布的好處。它定義為:
2、Combo Loss
組合損失定義為Dice loss和修正的交叉熵的加權(quán)和。它試圖利用Dice損失解決類不平衡問題的靈活性,同時使用交叉熵進(jìn)行曲線平滑。定義為:(DL指Dice Loss)
審核編輯 黃昊宇
-
圖像分割
+關(guān)注
關(guān)注
4文章
182瀏覽量
18003 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5503瀏覽量
121206
發(fā)布評論請先 登錄
相關(guān)推薦
評論