Fixes #8094.
### Description
The Dice, Jaccard and Tversky losses in `monai.losses.dice` and
`monai.losses.tversky` are modified based on
[JDTLoss](https://github.com/zifuwanggg/JDTLosses/blob/master/losses/jdt_loss.py)
and
[segmentation_models.pytorch](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/losses/_functional.py).
In the original versions, when `squared_pred=False`, the loss functions
are incompatible with soft labels. For example, with a ground-truth
value of 0.5 for a single pixel, the Dice loss is minimized when the
predicted value is 1, which is clearly erroneous. To address this, the
intersection term is rewritten as $\frac{\|x\|_p^p + \|y\|_p^p -
\|x-y\|_p^p}{2}$. When $p$ is 2 (`squared_pred=True`), this
reformulation reduces to the classical inner product $\langle x, y
\rangle$. When $p$ is 1 (`squared_pred=False`), the reformulation has
been proven to be equivalent to the original versions when the ground
truth is binary (i.e., one-hot hard labels). Moreover, since the new
versions are minimized if and only if the prediction is identical to
the ground truth, even when the ground truth includes fractional
values, they resolve the issue with soft labels [1, 2].
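The single-pixel case above can be checked directly: with a soft label $y=0.5$ and $p=1$, the original Dice score $\frac{2xy}{x+y}$ keeps increasing as the prediction $x$ grows toward 1, whereas the reformulated score peaks exactly at $x=y$. A minimal pure-Python sketch (scalar case, helper names are illustrative):

```python
def dice_old(x, y):
    # original L1 Dice score: 2 * <x, y> / (|x| + |y|)
    return 2 * x * y / (x + y)

def dice_new(x, y):
    # reformulated intersection: (|x| + |y| - |x - y|) / 2
    intersection = (x + y - abs(x - y)) / 2
    return 2 * intersection / (x + y)

y = 0.5  # soft ground-truth value for a single pixel
preds = [i / 100 for i in range(1, 101)]
best_old = max(preds, key=lambda x: dice_old(x, y))
best_new = max(preds, key=lambda x: dice_new(x, y))
print(best_old, best_new)  # 1.0 0.5
```

The original score is maximized by an overconfident prediction of 1.0, while the reformulated score is maximized only at the true value 0.5.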
In summary, there are three scenarios:
* [Scenario 1] $x$ is nonnegative and $y$ is binary: the new versions
are identical to the original versions.
* [Scenario 2] Both $x$ and $y$ are nonnegative: the new versions differ
from the original versions. The new versions are minimized if and only
if $x=y$, whereas the original versions may not be, making them
incorrect.
* [Scenario 3] Either $x$ or $y$ is negative: the new versions differ
from the original versions. As in Scenario 2, the new versions are
minimized if and only if $x=y$, whereas the original versions may not
be.
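For Scenario 1, the equivalence can be verified elementwise: when $y \in \{0, 1\}$ and $x \in [0, 1]$, $|x - y| = x + y - 2xy$, so $\frac{\|x\|_1 + \|y\|_1 - \|x - y\|_1}{2} = \langle x, y \rangle$. A quick numerical check in pure Python (illustrative variable names):

```python
import random

random.seed(0)
x = [random.random() for _ in range(1000)]       # predictions in [0, 1]
y = [random.randint(0, 1) for _ in range(1000)]  # binary hard labels

inner = sum(a * b for a, b in zip(x, y))
reform = (sum(x) + sum(y) - sum(abs(a - b) for a, b in zip(x, y))) / 2
print(abs(inner - reform) < 1e-6)  # True: identical for binary labels

# With soft labels (Scenario 2), the two intersections no longer agree:
y_soft = [random.random() for _ in range(1000)]
inner_s = sum(a * b for a, b in zip(x, y_soft))
reform_s = (sum(x) + sum(y_soft) - sum(abs(a - b) for a, b in zip(x, y_soft))) / 2
print(abs(inner_s - reform_s) > 1e-3)  # True: the reformulation differs
```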
Due to these differences, particularly in Scenarios 2 and 3, some tests
fail with the new versions:
* The target is non-binary: `test_multi_scale`
* The input is negative: `test_dice_loss`, `test_tversky_loss`,
`test_generalized_dice_loss`, `test_masked_loss`,
`test_seg_loss_integration`
The failures in `test_multi_scale` are expected since the original
versions are incorrectly defined for non-binary targets. Furthermore,
because Dice, Jaccard, and Tversky losses are fundamentally defined over
probabilities—which should be nonnegative—the new versions should not be
tested against negative input or target values.
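The negative-input failures follow the same pattern: with $p=1$, the reformulated score equals 1 whenever $x=y$, regardless of sign, while the original inner-product form does not. A small pure-Python illustration (scalar lists, illustrative helper names):

```python
def dice_old(x, y):
    # original L1 Dice: 2 * <x, y> / (||x||_1 + ||y||_1)
    inter = sum(a * b for a, b in zip(x, y))
    card = sum(abs(a) for a in x) + sum(abs(b) for b in y)
    return 2 * inter / card

def dice_new(x, y):
    # reformulated intersection: (||x||_1 + ||y||_1 - ||x - y||_1) / 2
    card = sum(abs(a) for a in x) + sum(abs(b) for b in y)
    diff = sum(abs(a - b) for a, b in zip(x, y))
    return (card - diff) / card

x = [-0.5, 0.25, 2.0]  # contains a negative value
print(dice_new(x, x))  # 1.0: perfect score when prediction equals target
print(dice_old(x, x))  # != 1.0, so tests comparing old and new outputs fail
```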
### Example
```python
import torch
import torch.linalg as LA
import torch.nn.functional as F
torch.manual_seed(0)
b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)
pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()
def dice_old(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality
print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims))
print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims))
print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims))
print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims))
print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims))
print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims))
# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
# tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935])
# tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503])
# tensor([1., 1., 1.]) tensor([1., 1., 1.])
```
### References
[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels.
Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew
B. Blaschko. *MICCAI 2023*.
[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft
Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. *NeurIPS 2023*.
### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Zifu Wang <zifuwang94@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>