Add MCC loss #8785 (Merged)

Changes from 1 commit (of 3 commits in this pull request).
monai/losses/mcc_loss.py (new file, 188 lines added):

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from collections.abc import Callable

import torch
from torch.nn.modules.loss import _Loss

from monai.networks import one_hot
from monai.utils import LossReduction


class MCCLoss(_Loss):
    """
    Compute the Matthews Correlation Coefficient (MCC) loss between two tensors.

    Unlike the Dice and Tversky losses, which only use TP, FP, and FN, the MCC loss considers all four entries
    of the confusion matrix (TP, TN, FP, FN), making it effective for class-imbalanced segmentation tasks
    where background dominates the image. The loss is computed as ``1 - MCC``, where
    ``MCC = (TP * TN - FP * FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))``.

    The soft confusion matrix entries are computed as:

    - ``TP = sum(input * target)``
    - ``TN = sum((1 - input) * (1 - target))``
    - ``FP = sum(input * (1 - target))``
    - ``FN = sum((1 - input) * target)``

    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).

    Note that axis N of `input` is expected to be logits or probabilities for each class; if passing logits,
    set `sigmoid=True` or `softmax=True`, or specify `other_act`. The same axis of `target`
    can have size 1 or N (one-hot format).

    The original paper:

        Abhishek, K. and Hamarneh, G. (2021) Matthews Correlation Coefficient Loss for Deep Convolutional
        Networks: Application to Skin Lesion Segmentation. IEEE ISBI, pp. 225-229.
        (https://doi.org/10.1109/ISBI48211.2021.9433782)
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 0.0,
        smooth_dr: float = 1e-5,
        batch: bool = False,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the calculation.
                If the non-background segmentations are small compared to the total image size, they can get
                overwhelmed by the signal from the background, so excluding it in such cases helps convergence.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: callable function to apply another activation layer, for example
                ``other_act = torch.tanh``. Defaults to ``None``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the confusion matrix entries over the batch dimension before computing MCC.
                Defaults to False, i.e. MCC is computed independently for each item in the batch
                before any `reduction`.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one-hot transform, if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> from monai.losses.mcc_loss import MCCLoss
            >>> import torch
            >>> B, C, H, W = 7, 1, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target = torch.randint(low=0, high=2, size=(B, C, H, W)).float()
            >>> self = MCCLoss(reduction='none')
            >>> loss = self(input, target)
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
        if self.batch:
            reduce_axis = [0] + reduce_axis

        # Soft confusion matrix entries (Eq. 5 in the paper).
        tp = torch.sum(input * target, dim=reduce_axis)
        tn = torch.sum((1.0 - input) * (1.0 - target), dim=reduce_axis)
        fp = torch.sum(input * (1.0 - target), dim=reduce_axis)
        fn = torch.sum((1.0 - input) * target, dim=reduce_axis)

        # MCC (Eq. 3) and loss (Eq. 4).
        numerator = tp * tn - fp * fn + self.smooth_nr
        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + self.smooth_dr)

        mcc = numerator / denominator
        score: torch.Tensor = 1.0 - mcc

        # When fp = fn = 0, the prediction is perfect, but the denominator product
        # tends to 0 when tp = 0 or tn = 0, giving mcc ~ 0 instead of 1.
        perfect = (fp == 0) & (fn == 0)
        score = torch.where(perfect, torch.zeros_like(score), score)

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(score)
        if self.reduction == LossReduction.NONE.value:
            return score
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(score)
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
```
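For reviewers who want to sanity-check the math without installing the branch, the soft-MCC computation above can be reproduced in a few lines of plain PyTorch. This is a standalone sketch that mirrors the `forward()` logic (no activations, one-hot handling, or reduction), not the MONAI implementation itself:

```python
import torch

def soft_mcc_loss(pred: torch.Tensor, target: torch.Tensor,
                  smooth_nr: float = 0.0, smooth_dr: float = 1e-5) -> torch.Tensor:
    # pred/target are BNHW[D] probabilities; reduce over spatial dims only.
    dims = list(range(2, pred.dim()))
    tp = (pred * target).sum(dim=dims)
    tn = ((1 - pred) * (1 - target)).sum(dim=dims)
    fp = (pred * (1 - target)).sum(dim=dims)
    fn = ((1 - pred) * target).sum(dim=dims)
    mcc = (tp * tn - fp * fn + smooth_nr) / torch.sqrt(
        (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + smooth_dr)
    loss = 1.0 - mcc
    # Perfect prediction: fp == fn == 0 must give zero loss even when the
    # denominator collapses (e.g. all-ones or all-zeros masks).
    return torch.where((fp == 0) & (fn == 0), torch.zeros_like(loss), loss)

# Perfect prediction -> loss 0; fully inverted prediction -> loss close to 2 (MCC = -1).
p = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
t = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
print(soft_mcc_loss(p, t))      # tensor([[0.]])
print(soft_mcc_loss(1 - p, t))  # ≈ tensor([[2.0]])
```

These two endpoints match the "perfect prediction" and "worst case (inverted)" entries in the test file below.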
The accompanying unit tests (new file, 154 lines added):

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import MCCLoss
from tests.test_utils import test_script_save

TEST_CASES = [
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},
        0.733197,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), perfect prediction
        {"include_background": True},
        {"input": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        0.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), worst case (inverted)
        {"include_background": True},
        {"input": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        2.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, exclude background, one-hot
        {"include_background": False, "to_onehot_y": True},
        {
            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
        },
        0.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot
        {"include_background": True, "to_onehot_y": True, "sigmoid": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.836617,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, batch=True
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "batch": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.845961,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, reduction=sum
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "sum"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        3.346468,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot
        {"include_background": True, "to_onehot_y": True, "softmax": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.730736,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot, reduction=none
        {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "none"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        [[0.461472, 0.461472], [1.0, 1.0]],
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-ones perfect prediction
        {"include_background": True},
        {"input": torch.ones(1, 1, 3, 3), "target": torch.ones(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-zeros perfect prediction
        {"include_background": True},
        {"input": torch.zeros(1, 1, 3, 3), "target": torch.zeros(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), other_act=torch.tanh
        {"include_background": True, "other_act": torch.tanh},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
]


class TestMCCLoss(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_data, expected_val):
        result = MCCLoss(**input_param).forward(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

    def test_ill_shape(self):
        loss = MCCLoss()
        with self.assertRaisesRegex(AssertionError, ""):
            loss.forward(torch.ones((2, 2, 3)), torch.ones((4, 5, 6)))
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(reduction="unknown")(chn_input, chn_target)
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(reduction=None)(chn_input, chn_target)

    def test_ill_opts(self):
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(sigmoid=True, softmax=True)
        with self.assertRaisesRegex(TypeError, ""):
            MCCLoss(other_act="tanh")

    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
    def test_input_warnings(self, include_background, softmax, to_onehot_y):
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertWarns(Warning):
            loss = MCCLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
            loss.forward(chn_input, chn_target)

    def test_script(self):
        loss = MCCLoss()
        test_input = torch.ones(2, 1, 8, 8)
        test_script_save(loss, test_input, test_input)


if __name__ == "__main__":
    unittest.main()
```
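The expected values in `TEST_CASES` can be derived by hand from the docstring formula. For instance, the first case (logits passed through a sigmoid) works out as follows; a quick check assuming only `torch`, not part of the PR:

```python
import torch

# First TEST_CASES entry: logits [[1, -1], [-1, 1]], target [[1, 0], [1, 1]].
p = torch.sigmoid(torch.tensor([[1.0, -1.0], [-1.0, 1.0]]))
t = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
tp = (p * t).sum()              # ~1.7311
tn = ((1 - p) * (1 - t)).sum()  # ~0.7311
fp = (p * (1 - t)).sum()        # ~0.2689
fn = ((1 - p) * t).sum()        # ~1.2689
# Denominator terms (tp+fp)=2, (tp+fn)=3, (tn+fp)=1, (tn+fn)=2 are exact,
# so the denominator is sqrt(12 + smooth_dr) with the default smooth_dr=1e-5.
mcc = (tp * tn - fp * fn) / torch.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-5)
loss = 1.0 - mcc
print(loss.item())  # ≈ 0.733197, the expected value in the first test case
```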