
Commit 9808ce2

zifuwanggg, KumoLiu, and pre-commit-ci[bot] authored
Modify Dice, Jaccard and Tversky losses (#8138)
Fixes #8094.

### Description

The Dice, Jaccard and Tversky losses in `monai.losses.dice` and `monai.losses.tversky` are modified based on [JDTLoss](https://github.com/zifuwanggg/JDTLosses/blob/master/losses/jdt_loss.py) and [segmentation_models.pytorch](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/losses/_functional.py).

In the original versions, when `squared_pred=False`, the loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous.

To address this, the intersection term is rewritten as $\frac{\|x\|_p^p + \|y\|_p^p - \|x-y\|_p^p}{2}$. When $p$ is 2 (`squared_pred=True`), this reformulation becomes the classical inner product $\langle x,y \rangle$. When $p$ is 1 (`squared_pred=False`), the reformulation has been proven to be equivalent to the original versions when the ground truth is binary (i.e. one-hot hard labels). Moreover, since the new versions are minimized if and only if the prediction is identical to the ground truth, even when the ground truth includes fractional values, they resolve the issue with soft labels [1, 2].

In summary, there are three scenarios:

* [Scenario 1] $x$ is nonnegative and $y$ is binary: the new versions are the same as the original versions.
* [Scenario 2] Both $x$ and $y$ are nonnegative: the new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.
* [Scenario 3] Either $x$ or $y$ is negative: the new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.

Due to these differences, particularly in Scenarios 2 and 3, some tests fail with the new versions:

* The target is non-binary: `test_multi_scale`
* The input is negative: `test_dice_loss`, `test_tversky_loss`, `test_generalized_dice_loss`, `test_masked_loss`, `test_seg_loss_integration`

The failures in `test_multi_scale` are expected, since the original versions are incorrectly defined for non-binary targets. Furthermore, because Dice, Jaccard, and Tversky losses are fundamentally defined over probabilities, which should be nonnegative, the new versions should not be tested against negative input or target values.

### Example

```python
import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()


def dice_old(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality


def dice_new(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality


print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims))
print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims))
print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims))
print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims))
print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims))
print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
# tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935])
# tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503])
# tensor([1., 1., 1.]) tensor([1., 1., 1.])
```

### References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. *MICCAI 2023*.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. *NeurIPS 2023*.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Zifu Wang <zifuwang94@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
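As an editorial illustration of the fix (not part of the commit), here is a minimal sketch of the new `soft_label` flag on `DiceLoss`, assuming a MONAI build that includes this commit: with a soft target identical to the prediction, the original formulation does not reach its minimum, while the soft-label formulation does.

```python
import torch
from monai.losses import DiceLoss

# a soft prediction and an identical soft target: the ideal case
pred = torch.tensor([[[[0.3, 0.7]]]])  # shape (1, 1, 1, 2)
target = pred.clone()

# original formulation: the loss stays far from 0 even though pred == target
print(DiceLoss(soft_label=False)(pred, target))  # ~0.42

# soft-label formulation: the loss is minimized exactly when pred == target
print(DiceLoss(soft_label=True)(pred, target))   # ~0.0 (up to smoothing terms)
```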
1 parent 7f88a46 commit 9808ce2

File tree

6 files changed: +160 -30 lines changed


monai/losses/dice.py

Lines changed: 33 additions & 22 deletions
```diff
@@ -23,6 +23,7 @@
 
 from monai.losses.focal_loss import FocalLoss
 from monai.losses.spatial_mask import MaskedLoss
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
 
@@ -39,8 +40,16 @@ class DiceLoss(_Loss):
     The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
     the inter-over-union calculation to smooth results respectively, these values should be small.
 
-    The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric
-    Medical Image Segmentation, 3DV, 2016.
+    The original papers:
+
+    Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
+    Medical Image Segmentation. 3DV 2016.
+
+    Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
+    Soft Labels. NeurIPS 2023.
+
+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
 
     """
 
@@ -58,6 +67,7 @@ def __init__(
         smooth_dr: float = 1e-5,
         batch: bool = False,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -89,6 +99,8 @@ def __init__(
                 of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -114,6 +126,7 @@ def __init__(
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -174,21 +187,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        intersection = torch.sum(target * input, dim=reduce_axis)
-
-        if self.squared_pred:
-            ground_o = torch.sum(target**2, dim=reduce_axis)
-            pred_o = torch.sum(input**2, dim=reduce_axis)
-        else:
-            ground_o = torch.sum(target, dim=reduce_axis)
-            pred_o = torch.sum(input, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-
-        if self.jaccard:
-            denominator = 2.0 * (denominator - intersection)
+        ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
+        if not self.jaccard:
+            fp *= 0.5
+            fn *= 0.5
+        numerator = 2 * tp + self.smooth_nr
+        denominator = 2 * (tp + fp + fn) + self.smooth_dr
 
-        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
+        f: torch.Tensor = 1 - numerator / denominator
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -272,6 +279,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -295,6 +303,8 @@
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, intersection over union is computed from each item in the batch.
                 If True, the class-weighted intersection and union areas are first summed across the batches.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -319,6 +329,7 @@
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -370,13 +381,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
-        intersection = torch.sum(target * input, reduce_axis)
 
-        ground_o = torch.sum(target, reduce_axis)
-        pred_o = torch.sum(input, reduce_axis)
-
-        denominator = ground_o + pred_o
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
+        fp *= 0.5
+        fn *= 0.5
+        denominator = 2 * (tp + fp + fn)
 
+        ground_o = torch.sum(target, reduce_axis)
         w = self.w_func(ground_o.float())
         infs = torch.isinf(w)
         if self.batch:
@@ -388,7 +399,7 @@
             w = w + infs * max_values
 
         final_reduce_dim = 0 if self.batch else 1
-        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
         denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
         f: torch.Tensor = 1.0 - (numer / denom)
 
```
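A quick standalone check of the refactor above (an editor's sketch, not from the commit): for nonnegative inputs, the reformulated L1 intersection $(\|x\|_1 + \|y\|_1 - \|x-y\|_1)/2$ equals the sum of elementwise minima, and with a binary target it collapses to the original inner product, which is why Scenario 1 behavior is unchanged.

```python
import torch
import torch.linalg as LA

torch.manual_seed(0)
dims = (0, 2, 3)
x = torch.rand(2, 3, 4, 4).softmax(dim=1)       # nonnegative prediction
y_soft = torch.rand(2, 3, 4, 4).softmax(dim=1)  # soft target
y_hard = (y_soft == y_soft.amax(dim=1, keepdim=True)).float()  # binary target


def tp_l1(x, y):
    # the reformulated intersection with p = 1
    return (x.sum(dims) + y.sum(dims) - LA.vector_norm(x - y, ord=1, dim=dims)) / 2


# for nonnegative x, y it equals the sum of elementwise minima ...
assert torch.allclose(tp_l1(x, y_soft), torch.minimum(x, y_soft).sum(dims))
# ... and for a binary y it reduces to the original sum(x * y)
assert torch.allclose(tp_l1(x, y_hard), (x * y_hard).sum(dims))
```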

monai/losses/tversky.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -17,6 +17,7 @@
 import torch
 from torch.nn.modules.loss import _Loss
 
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import LossReduction
 
@@ -28,6 +29,9 @@ class TverskyLoss(_Loss):
     Sadegh et al. (2017) Tversky loss function for image segmentation
     using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721)
 
+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
+
     Adapted from:
         https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631
 
@@ -46,6 +50,7 @@
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -70,6 +75,8 @@
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -93,6 +100,7 @@
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -134,20 +142,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if target.shape != input.shape:
             raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
 
-        p0 = input
-        p1 = 1 - p0
-        g0 = target
-        g1 = 1 - g0
-
         # reducing only spatial dimensions (not batch nor channels)
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        tp = torch.sum(p0 * g0, reduce_axis)
-        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
-        fn = self.beta * torch.sum(p1 * g0, reduce_axis)
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
+        fp *= self.alpha
+        fn *= self.beta
         numerator = tp + self.smooth_nr
         denominator = tp + fp + fn + self.smooth_dr
 
```
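One consequence worth noting (an editor's sketch, assuming a MONAI build with this commit): with `alpha = beta = 0.5` the Tversky index reduces to the Dice score, and since `compute_tp_fp_fn` returns the same `tp`/`fp`/`fn` for both losses on these inputs, the two losses coincide exactly once the smoothing terms are zeroed.

```python
import torch
from monai.losses import DiceLoss, TverskyLoss

torch.manual_seed(0)
pred = torch.rand(2, 3, 8, 8).softmax(dim=1)
target = torch.rand(2, 3, 8, 8).softmax(dim=1)

# holds for both the original and the soft-label formulation
for soft in (False, True):
    dice = DiceLoss(smooth_nr=0.0, smooth_dr=0.0, soft_label=soft)(pred, target)
    tversky = TverskyLoss(alpha=0.5, beta=0.5, smooth_nr=0.0, smooth_dr=0.0, soft_label=soft)(pred, target)
    assert torch.allclose(dice, tversky, atol=1e-6)
```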

monai/losses/utils.py

Lines changed: 68 additions & 0 deletions
```diff
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.linalg as LA
+
+
+def compute_tp_fp_fn(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    reduce_axis: list[int],
+    ord: int,
+    soft_label: bool,
+    decoupled: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        input: the shape should be BNH[WD], where N is the number of classes.
+        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+        reduce_axis: the axis to be reduced.
+        ord: the order of the vector norm.
+        soft_label: whether the target contains non-binary values (soft labels) or not.
+            If True a soft label formulation of the loss will be used.
+        decoupled: whether the input and the target should be decoupled when computing fp and fn.
+            Only for the original implementation when soft_label is False.
+
+    Adapted from:
+        https://github.com/zifuwanggg/JDTLosses
+    """
+
+    # the original implementation that is erroneous with soft labels
+    if ord == 1 and not soft_label:
+        tp = torch.sum(input * target, dim=reduce_axis)
+        # the original implementation of Dice and Jaccard loss
+        if decoupled:
+            fp = torch.sum(input, dim=reduce_axis) - tp
+            fn = torch.sum(target, dim=reduce_axis) - tp
+        # the original implementation of Tversky loss
+        else:
+            fp = torch.sum(input * (1 - target), dim=reduce_axis)
+            fn = torch.sum((1 - input) * target, dim=reduce_axis)
+    # the new implementation that is correct with soft labels
+    # and it is identical to the original implementation with hard labels
+    else:
+        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
+        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
+        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)
+
+        if ord > 1:
+            pred_o = torch.pow(pred_o, exponent=ord)
+            ground_o = torch.pow(ground_o, exponent=ord)
+            difference = torch.pow(difference, exponent=ord)
+
+        tp = (pred_o + ground_o - difference) / 2
+        fp = pred_o - tp
+        fn = ground_o - tp
+
+    return tp, fp, fn
```
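A direct check of the norm-based branch (an editor's sketch, not part of the commit): with `ord=2` the reformulated `tp` is the polarization identity $\langle x,y \rangle = (\|x\|_2^2 + \|y\|_2^2 - \|x-y\|_2^2)/2$, and `fp`/`fn` complete the squared cardinalities.

```python
import torch
from monai.losses.utils import compute_tp_fp_fn

torch.manual_seed(0)
x = torch.rand(2, 3, 8, 8).softmax(dim=1)
y = torch.rand(2, 3, 8, 8).softmax(dim=1)
reduce_axis = [2, 3]

tp, fp, fn = compute_tp_fp_fn(x, y, reduce_axis, ord=2, soft_label=True)
# tp recovers the classical inner product <x, y>
assert torch.allclose(tp, (x * y).sum(dim=reduce_axis), atol=1e-6)
# fp and fn complete the squared cardinalities ||x||^2 and ||y||^2
assert torch.allclose(tp + fp, (x**2).sum(dim=reduce_axis), atol=1e-6)
assert torch.allclose(tp + fn, (y**2).sum(dim=reduce_axis), atol=1e-6)
```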

tests/test_dice_loss.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -34,6 +34,22 @@
         },
         0.416657,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307773,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
         {
```
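The two new expected values can be verified by hand (editor's arithmetic, assuming the defaults `reduction="mean"` and `batch=False`). With `input == target`, the soft-label branch gives `tp = sum(x)` and `fp = fn = 0`, so the loss is 0; the original branch uses `sum(x * x)` as the intersection and stays well above 0.

```python
import torch

x = torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]])
smooth = 1e-4

# soft_label=False: the intersection is sum(x * x), not sum(x)
per_item = 1 - (2 * (x * x).sum(dim=(2, 3)) + smooth) / (2 * x.sum(dim=(2, 3)) + smooth)
print(per_item.mean())  # tensor(0.3078), i.e. the expected 0.307773
```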

tests/test_generalized_dice_loss.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -34,6 +34,22 @@
         },
         0.416597,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307748,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0},
         {
```
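The expected value here, 0.307748, differs slightly from the 0.307773 in `test_dice_loss` because of the generalized Dice class weighting (editor's arithmetic, assuming the default `w_type="square"`, i.e. `w = 1 / sum(target) ** 2`, with the smoothing terms applied after weighting).

```python
import torch

x = torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]])
smooth = 1e-4

tp = (x * x).sum(dim=(2, 3))    # intersection, since input == target
denom = 2 * x.sum(dim=(2, 3))   # pred_o + ground_o
w = 1 / x.sum(dim=(2, 3)) ** 2  # squared-reciprocal class weight

per_item = 1 - (2 * tp * w + smooth) / (denom * w + smooth)
print(per_item.mean())  # tensor(0.3077), i.e. the expected 0.307748
```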

tests/test_tversky_loss.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -34,6 +34,22 @@
         },
         0.416657,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307773,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
         {
```
