图像分割损失函数

最近参加kaggle比赛,才发现对于图像分割损失函数有各种形式。同时,关于如何实现这些损失函数,尤其是加权的损失函数,之前并没有研究过。但是在实际应用中,应该还是挺常见的,毕竟样本不均衡问题时有发生。好了,废话不多说了, 进入正题。下面的内容均以二分类问题为例。

cross entropy

图像分割任务的本质为对于像素点的分类,通常称为密集预测(dense prediction)。分类问题自然可以使用cross entropy(交叉熵损失函数)。

设真实情况下$\mathbf{P}(Y = 0) = p$,$\mathbf{P}(Y = 1) = 1 - p$。通过 logistic/sigmoid 函数得到的预测$\mathbf{P}(\hat{Y} = 0) = \frac{1}{1 + e^{-x}} = \hat{p}$,$\mathbf{P}(\hat{Y} = 1) = 1 - \frac{1}{1 + e^{-x}} = 1 - \hat{p}$,则交叉熵损失函数CE为

在keras中,对应函数为binary_crossentropy(y_true, y_pred),在TensorFlow中,对应函数为softmax_cross_entropy_with_logits_v2,在Pytorch中,对应的损失函数为torch.nn.BCEWithLogitsLoss()

Weighted cross entropy

Weighted cross entropy是cross entropy的一种变体,具体体现在所有的正例损失前均有一个系数。主要用于类别不平衡的问题,例如当图像中只有10%的正样本,而有90%的负样本的时候,常规的cross entropy不能正常的work。

如果想减少false negatives(漏报),即增加recall,则设置$\beta>1$;若想减少false positives(误报),则增加precision,则设置$\beta<1$。这个可以这么理解:

  • 当$\beta>1$的时候,$p_{i,j} \log\left(\hat{p}_{i,j}\right)$的系数较大,所谓false negatives(漏报),就是指预测错了,预测为了负样本,实际类别为正样本,此时$p_{i,j}=1$,为了使得损失尽可能的小,会导致$\hat{p}_{i,j}$尽可能大,模型更加倾向于尽可能的减少漏报;
  • 当$\beta<1$的时候,$(1-p_{i,j}) \log\left(1 - \hat{p}_{i,j}\right)$的系数较大,所谓false positives(误报),就是指预测错了,预测为了正样本,实际类别为负样本,此时$1-p_{i,j}=1$,为了使得损失尽可能的小,会导致$\hat{p}_{i,j}$尽可能小,模型更加倾向于尽可能的减少误报。

例如,当数据集中含有100个正例,300个负例的时候,Pytorch中的torch.nn.BCEWithLogitsLoss()函数中的pos_weight参数需要为$\frac{300}{100}=3$。此时的loss相当于有关$100\times3=300$个样本。

Balanced cross entropy

该损失函数和WCE基本一致,不同点在于该损失函数对负样本也进行了加权。

上面的公式均是针对每个样本均有一个权重。对于图像分割任务,相当于对所有样本的所有像素点均有一个权重。且该公式中,不管是正样本还是负样本的损失,均要除以$batch_size \times image_size$来得到均值。

除此之外,在遇到类别不均衡的时候,当计算正负样本损失的时候,分别所以各自的总数,然后加权。这样做的好处是,防止正样本数目过少导致求和后除以$batch_size \times image_size$值很小。当正样本的权值为0.25,负样本的权值为0.75的时候,具体公式可以描述如下:

其中,$p_{i,j}$为一个batch内所有样本所有像素点是否为正样本,为正样本为1,不为正样本为0;$n_{i,j}$为一个batch内所有样本所有像素点是否为负样本,为正样本为0,不为正样本为1;$loss_{i,j}$为一个batch内所有样本所有像素点的损失值。

正样本和负样本权重分别为0.25和0.75是针对SIIM-ACR Pneumothorax Segmentation比赛的。在该比赛中,有掩模的样本总数和无掩模的样本总数大概为1:3,也就相当于0.25:0.75。若不进行加权,则正样本和负样本的损失值基本相同,这不符合实际的数据分布,会导致最终可能出现没有掩模的也预测出了掩模的情况。PS:实际使用的时候,效果特别差。具体原因未知。

具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# reference: https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429
def criterion_pixel(logit_pixel, truth_pixel):
logit = logit_pixel.view(-1)
truth = truth_pixel.view(-1)
assert(logit.shape==truth.shape)

loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
if 0:
loss = loss.mean()
if 1:
pos = (truth>0.5).float()
neg = (truth<0.5).float()
pos_weight = pos.sum().item() + 1e-12
neg_weight = neg.sum().item() + 1e-12
loss = (0.25*pos*loss/pos_weight + 0.75*neg*loss/neg_weight).sum()

return loss

DiceLoss

DICE与IOU很相似,具体两者的区别如下:

从中可以看出,$\text{DC} \geq \text{IoU}$(两者相减得到的式子中分子为$|X| + |Y| - 2|X \cap Y|$,显然分子大于0)。

DICE也可以作为loss使用,具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# reference: https://github.com/asanakoy/kaggle_carvana_segmentation
def dice_loss(preds, trues, weight=None, is_average=True):
num = preds.size(0)
preds = preds.view(num, -1)
trues = trues.view(num, -1)
if weight is not None:
w = torch.autograd.Variable(weight).view(num, -1)
preds = preds * w
trues = trues * w
intersection = (preds * trues).sum(1)
# 分母加1是为了保证不为0,分子加1是为了保证最大值为1
scores = (2. * intersection + 1) / (preds.sum(1) + trues.sum(1) + 1)

if is_average:
score = scores.sum() / num
# clamp函数是为了保证值在[0,1]之间,防止下面的DiceLoss出现负值
return torch.clamp(score, 0., 1.)
else:
return scores

class DiceLoss(nn.Module):
"""
"""
def __init__(self, size_average=True):
super().__init__()
self.size_average = size_average

def forward(self, input, target, weight=None):
return 1 - dice_loss(F.sigmoid(input), target, weight=weight, is_average=self.size_average)

这里解释下dice_loss函数内部的加权。preds.size(0)得到batch_size大小,w的大小为batch_size*image_size,则可以得到下式:

其中,$\hat p_{i,j}$为一个batch第$i$个样本第$j$个像素的预测值,而$p_{i,j}$为一个batch第$i$个样本第$j$个像素的真实值。

所以,这里的加权就相当于对一个batch内的所有样本的loss进行加权,和Pytorch中的BCEWithLogitsLoss中的weight参数含义一致。对于图像分割任务,相当于对一个batch内的所有样本的所有像素点进行加权。

一方面,这样的加权方式不经常使用,因为我们经常会遇到正样本和负样本比例失衡问题,对于所有样本的所有像素点均要设置一个权值,在实现上不如直接设置正样本和负样本的权值方便,类似于Pytorch中的BCEWithLogitsLoss中的pos_weight参数含义。PS:暂时没有实现,所以还是老老实实没一个样本设置一个权值吧。

另一方面,值得注意的是,在图像分割任务中,会碰到样本mask中没有正样本的情况。例如在SIIM-ACR Pneumothorax Segmentation比赛中,就会出现大部分图像中并没有目标,mask也就全部为负样本。对于mask全部为负样本的数据,若预测出mask也没正样本,上面的dice_loss函数分子接近0,导致最终的loss很大,然而真实情况应该为此时loss应该很小。因此可以考虑下面的dice函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# dice for threshold selection
def dice_overall(self, preds, targs):
n = preds.shape[0] # batch size为多少
preds = preds.view(n, -1)
targs = targs.view(n, -1)
# preds, targs = preds.to(self.device), targs.to(self.device)
preds, targs = preds.cpu(), targs.cpu()

# tensor之间按位相成,求两个集合的交(只有1×1等于1)后。按照第二个维度求和,得到[batch size]大小的tensor,每一个值代表该输入图片真实类标与预测类标的交集大小
intersect = (preds * targs).sum(-1).float()
# tensor之间按位相加,求两个集合的并。然后按照第二个维度求和,得到[batch size]大小的tensor,每一个值代表该输入图片真实类标与预测类标的并集大小
union = (preds + targs).sum(-1).float()
'''
输入图片真实类标与预测类标无并集有两种情况:第一种为预测与真实均没有类标,此时并集之和为0;第二种为真实有类标,但是预测完全错误,此时并集之和不为0;

寻找输入图片真实类标与预测类标并集之和为0的情况,将其交集置为1,并集置为2,最后还有一个2*交集/并集,值为1;
其余情况,直接按照2*交集/并集计算,因为上面的并集并没有减去交集,所以需要拿2*交集,其最大值为1
'''
u0 = union == 0
intersect[u0] = 1
union[u0] = 2

return (2. * intersect / union)

那么如何实现对所有样本的所有像素点分配权重呢?这需要引入下面的SoftDICELoss

SoftDICELoss

这个loss是一个kaggle的大神提出来的。该损失函数克服了上面DiceLoss损失函数没有考虑

DICE还有另外一个形式:

其中,$\mathbf{p} \in \{0,1\}^n$,$\mathbf{\hat p} \in [0,1]^n$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# reference https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429#latest-588288
class SoftDiceLoss(nn.Module):
"""二分类加权dice损失
"""
def __init__(self, size_average=True, weight=[0.2, 0.8]):
"""
weight: 各类别权重
"""
super(SoftDiceLoss, self).__init__()
self.size_average = size_average
self.weight = torch.FloatTensor(weight)

def forward(self, logit_pixel, truth_pixel):
batch_size = len(logit_pixel)
logit = logit_pixel.view(batch_size, -1)
truth = truth_pixel.view(batch_size, -1)
assert(logit.shape == truth.shape)

loss = self.soft_dice_criterion(logit, truth)

if self.size_average:
loss = loss.mean()
return loss

def soft_dice_criterion(self, logit, truth):
batch_size = len(logit)
probability = torch.sigmoid(logit)

p = probability.view(batch_size, -1)
t = truth.view(batch_size, -1)
# 向各样本所有像素点分配所属类别的权重,此时w只有0或1两个值
w = truth.detach()
self.weight = self.weight.type_as(logit)
# * 和 -1 均为像素点之间的运算,默认此时负样本处w=0.2和正样本处w=0.8
w = w * (self.weight[1] - self.weight[0]) + self.weight[0]

p = w * (p*2 - 1) #convert to [0,1] --> [-1, 1]
t = w * (t*2 - 1)

intersection = (p * t).sum(-1)
union = (p * p).sum(-1) + (t * t).sum(-1)
dice = 1 - 2 * intersection/union

loss = dice
return loss

解释下上面代码,因为考虑到数据集中可能存在某些样本的掩模均为负样本,没有正样本的情况,所以需要将类标从[0,1]变为[-1,1]。此时,若真实掩模没有mask,预测出来也全部没有mask,不会因为全部值为0,导致dice的分子为0,loss为1。相反此时全部值为-1,dice的值为1,loss为0,更符合我们的实际需求。

另外,所谓的使用加权的损失函数解决样本不均衡问题,是指对于每一个正样本和负样本均有对应的加权系数。上面代码可以总结为公式:

其中,$t_{i,j} \in {-1,1}$,而$p_{i,j} \in [-1,1]$。若正样本的系数$w_{i,j}$为0.8,而负样本的系数$w_{i,j}$为0.2,则正样本对dice的影响更大,负样本对dice的影响更小。从而让网络更加关注正样本。

FocalLoss

该损失函数降低easy examples的权重,使得模型更加关注hard examples

其中$\gamma$为超参数,当$\gamma = 0$的时候,我们得到标准BCE。我们这里关注的为当$\gamma \not= 0$的时候,对于这个公式的理解如下。

当$\gamma>1$时:

  • 当样本为正样本时,此时上式右边只有第一项不为0。若$\hat{p}$较大的时候,意味着网络对该数据的分类效果较好,$(1 - \hat{p})^{\gamma}$值较小,意味着该数据的loss更小,网络接下来对于该数据的关注会更小;反之,当$\hat{p}$较小的时候,意味着网络对该数据的分类效果较差,$(1 - \hat{p})^{\gamma}$值较大,意味着该数据的loss更大,网络接下来对于该数据的关注会更大。
  • 当样本为负样本时,此时上式右边只有第二项不为0。若$\hat{p}$较大的时候,意味着网络对该数据的分类效果较差,$\hat{p}^{\gamma}$值较大,意味着该数据的loss更大,网络接下来对于该数据的关注会更大;反之,当$\hat{p}$较小的时候,意味着网络对该数据的分类效果较好,$\hat{p}^{\gamma}$值较小,意味着该数据的loss更小,网络接下来对于该数据的关注会更小。

当$\gamma<1$的时候,此时损失函数会越加关注容易分的样本,而越加不关注难分的样本,与该损失函数的设计初衷背道而驰。所以实际使用的时候,$\gamma\geq1$。

因为我们这里使用的是logistic/sigmoid 函数预测的,所以继续进行推导,可以得到

参考

Losses for Image Segmentation
BCEWithLogitsLoss
losses.py
some workable loss function
How to apply weighted loss to a binary segmentation problem?

------ 本文结束------
坚持原创技术分享,您的支持将鼓励我继续创作!

欢迎关注我的其它发布渠道