Notes on Dice
Notation: $|X|$ denotes the number of elements of $X$.
Sørensen–Dice coefficient
Given two sets $X$ and $Y$, the Dice coefficient is

$$DSC = \frac{2\left|X \cap Y\right|}{\left|X\right| + \left|Y\right|}$$
If the two sets are identical, the Dice coefficient equals 1.
If both sets are binary masks (containing only 0s and 1s), then

$$DSC = \frac{2TP}{2TP + FP + FN}$$
From this it is also clear that Dice lies between 0 and 1.
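As a quick sanity check that the set form and the TP/FP/FN form agree, here is a minimal sketch with two made-up binary masks (`a` as ground truth, `b` as prediction):

```python
import torch

a = torch.tensor([1, 1, 0, 0, 1, 0])  # ground-truth mask
b = torch.tensor([1, 0, 0, 1, 1, 0])  # predicted mask

# set form: 2|X ∩ Y| / (|X| + |Y|)
inter = torch.sum(a * b)
dsc_set = 2.0 * inter / (a.sum() + b.sum())

# confusion-matrix form: 2TP / (2TP + FP + FN)
tp = torch.sum((a == 1) & (b == 1))
fp = torch.sum((a == 0) & (b == 1))
fn = torch.sum((a == 1) & (b == 0))
dsc_cm = 2.0 * tp / (2 * tp + fp + fn)

print(dsc_set.item(), dsc_cm.item())  # both 0.666...
```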
For two vectors, the coefficient becomes

$$DSC = \frac{2\,\mathbf{p}\cdot\mathbf{q}}{\mathbf{p}\cdot\mathbf{p} + \mathbf{q}\cdot\mathbf{q}}$$
Differentiating with respect to $\mathbf{p}$ (quotient rule) gives

$$\frac{\partial DSC}{\partial \mathbf{p}} = 2\,\frac{\mathbf{q}\left(\mathbf{p}\cdot\mathbf{p} + \mathbf{q}\cdot\mathbf{q}\right) - 2\mathbf{p}\left(\mathbf{p}\cdot\mathbf{q}\right)}{\left(\mathbf{p}\cdot\mathbf{p} + \mathbf{q}\cdot\mathbf{q}\right)^2}$$

(the gradient of the Dice loss $1 - DSC$ is simply the negative of this).
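To double-check the form of this gradient, one can compare it against PyTorch autograd (a minimal sketch with made-up vectors):

```python
import torch

p = torch.tensor([0.8, 0.3, 0.6], requires_grad=True)
q = torch.tensor([1.0, 0.0, 1.0])

# DSC in vector form
dsc = 2 * torch.dot(p, q) / (torch.dot(p, p) + torch.dot(q, q))
dsc.backward()

# analytic gradient from the formula above
with torch.no_grad():
    denom = torch.dot(p, p) + torch.dot(q, q)
    grad = 2 * (q * denom - 2 * p * torch.dot(p, q)) / denom**2

print(torch.allclose(p.grad, grad))  # True
```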
Relationship between DSC and JAC (IoU)
$$JAC = \frac{\left|X\cap Y\right|}{\left|X\cup Y\right|} = \frac{\left|X\cap Y\right|}{\left|X\right| + \left|Y\right| - \left|X\cap Y\right|}$$
Therefore, writing $i = \left|X\cap Y\right|$ and $s = \left|X\right| + \left|Y\right|$ so that $DSC = 2i/s$ and $JAC = i/(s - i)$, a little algebra gives

$$JAC = \frac{DSC}{2 - DSC}$$
Since $x \mapsto x/(2 - x)$ is strictly increasing, whenever Dice increases, JAC must increase too (so if you see a paper where Dice went up but JAC did not…).
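The identity is easy to verify numerically on the masks from the earlier example:

```python
import torch

a = torch.tensor([1, 1, 0, 0, 1, 0])
b = torch.tensor([1, 0, 0, 1, 1, 0])

inter = torch.sum(a * b).float()
union = torch.sum(a | b).float()  # |X ∪ Y| for binary masks

dsc = 2 * inter / (a.sum() + b.sum())
jac = inter / union

print(jac.item(), (dsc / (2 - dsc)).item())  # both 0.5
```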
Dice loss
$$Dice\ Loss = 1 - DSC$$
Code
The code below handles multiple classes; it effectively computes the Dice loss once per class and averages the results.
```python
import torch


def my_dice(pred, target, smooth=1e-5, eps=1e-5):
    """
    :param pred: prediction (BCHW / BCHWD)
    :param target: target (BHW / BHWD)
    :param smooth: smoothing term added to the numerator
    :param eps: small constant to prevent division by zero
    :return: dice loss
    """
    C = pred.size(1)
    # B1HW / B1HWD
    target = target.unsqueeze(1)
    sh = list(target.shape)
    sh[1] = C
    o = torch.zeros(size=sh, dtype=target.dtype, device=target.device)
    # one-hot encode the labels
    target = o.scatter_(dim=1, index=target.long(), value=1)
    # sum over the spatial dimensions
    reduce_axis = list(range(2, len(pred.shape)))
    intersection = torch.sum(target * pred, dim=reduce_axis)
    ground_o = torch.sum(target, dim=reduce_axis)
    pred_o = torch.sum(pred, dim=reduce_axis)
    denominator = ground_o + pred_o
    # per-sample, per-class dice loss, then mean over everything
    result = 1.0 - (2.0 * intersection + smooth) / (denominator + eps)
    result = torch.mean(result)
    return result
```
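A quick usage check with random inputs (a minimal sketch; the shapes are made up, and `pred` is assumed to already be softmax probabilities):

```python
import torch

B, C, H, W = 2, 3, 8, 8
pred = torch.softmax(torch.randn(B, C, H, W), dim=1)  # per-pixel class probabilities
target = torch.randint(0, C, (B, H, W))               # integer class labels

loss = my_dice(pred, target)
print(loss)  # a scalar Dice loss
```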
Alternatively, you can use MONAI:

```python
from monai.losses import DiceLoss
```
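A sketch of how the MONAI version might be called (based on the `to_onehot_y`/`softmax` constructor options; check the docs of your MONAI version for the exact behavior):

```python
import torch
from monai.losses import DiceLoss

# to_onehot_y converts integer labels to one-hot; softmax is applied to the logits
criterion = DiceLoss(to_onehot_y=True, softmax=True)

pred = torch.randn(2, 3, 8, 8)              # raw logits (BCHW)
target = torch.randint(0, 3, (2, 1, 8, 8))  # integer labels with a channel dim (B1HW)

loss = criterion(pred, target)
```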
Generalized Dice Loss
$$GDL = 1 - 2\,\frac{\sum_{l=1}^{2} w_l \sum_n r_{ln} p_{ln}}{\sum_{l=1}^{2} w_l \sum_n \left(r_{ln} + p_{ln}\right)}$$
where $r$ is the ground truth, $p$ is the prediction, $l$ indexes the classes (two in the original formulation), $n$ indexes the pixels, and

$$w_l = \frac{1}{\left(\sum_{n} r_{ln}\right)^2}$$
Code
```python
import torch


def my_generalized_dice(pred, target, smooth=1e-5, eps=1e-5):
    """
    :param pred: prediction (BCHW / BCHWD)
    :param target: target (BHW / BHWD)
    :param smooth: smoothing term added to the numerator
    :param eps: small constant to prevent division by zero
    :return: generalized dice loss
    """
    C = pred.size(1)
    # B1HW / B1HWD
    target = target.unsqueeze(1)
    sh = list(target.shape)
    sh[1] = C
    o = torch.zeros(size=sh, dtype=target.dtype, device=target.device)
    # one-hot encode the labels
    target = o.scatter_(dim=1, index=target.long(), value=1)
    # sum over the spatial dimensions
    reduce_axis = list(range(2, len(pred.shape)))
    intersection = torch.sum(target * pred, dim=reduce_axis)
    ground_o = torch.sum(target, dim=reduce_axis)
    pred_o = torch.sum(pred, dim=reduce_axis)
    denominator = ground_o + pred_o
    # class weights w_l = 1 / (sum_n r_ln)^2
    w = 1 / (ground_o * ground_o)
    infs = torch.isinf(w)
    # classes absent from the ground truth get inf weights;
    # replace them with the largest finite weight of the sample
    w[infs] = 0.0
    max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
    w = w + infs * max_values
    # weighted sum over classes, then mean over the batch
    numer = 2.0 * torch.sum(intersection * w, dim=1) + smooth
    denom = torch.sum(denominator * w, dim=1) + eps
    result = 1.0 - numer / denom
    result = torch.mean(result)
    return result
```
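A quick usage check, including an empty class to exercise the inf-weight handling (a minimal sketch with made-up shapes):

```python
import torch

B, C, H, W = 2, 3, 8, 8
pred = torch.softmax(torch.randn(B, C, H, W), dim=1)
# labels drawn only from {0, 1}, so class 2 has no pixels and its weight 1/0 is inf
target = torch.randint(0, C - 1, (B, H, W))

loss = my_generalized_dice(pred, target)
print(loss)  # still a finite scalar thanks to the inf handling
```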
MONAI also provides an implementation:

```python
from monai.losses import GeneralizedDiceLoss
```
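Usage mirrors `DiceLoss` above (again a sketch; verify the constructor arguments against your MONAI version):

```python
import torch
from monai.losses import GeneralizedDiceLoss

criterion = GeneralizedDiceLoss(to_onehot_y=True, softmax=True)

pred = torch.randn(2, 3, 8, 8)              # raw logits
target = torch.randint(0, 3, (2, 1, 8, 8))  # labels with a channel dim

loss = criterion(pred, target)
```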
That said, MONAI's version feels a bit odd to me: it computes the value per sample and per class, and then averages them all together.
References
Sørensen–Dice coefficient, Wikipedia: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
Milletari et al., V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
https://zhuanlan.zhihu.com/p/269592183
Sudre et al., Generalised Dice Overlap as a Deep Learning Loss Function for Highly Unbalanced Segmentations