From b5107da16ad3223ebbd32b896a60c288d2653546 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Thu, 12 Mar 2026 16:56:16 -0400 Subject: [PATCH 1/8] Add parameter to DiceMetric and DiceHelper classes Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/meandice.py | 88 +++++++++++++++++++++++++- tests/metrics/test_compute_meandice.py | 38 +++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index fedd94fb93..d61d07475c 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -11,7 +11,11 @@ from __future__ import annotations +import numpy as np import torch +import torch.nn.functional as F +from scipy.ndimage import distance_transform_edt, generate_binary_structure +from scipy.ndimage import label as sn_label from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction, deprecated_arg @@ -106,6 +110,7 @@ def __init__( ignore_empty: bool = True, num_classes: int | None = None, return_with_label: bool | list[str] = False, + per_component: bool = False, ) -> None: super().__init__() self.include_background = include_background @@ -114,6 +119,7 @@ def __init__( self.ignore_empty = ignore_empty self.num_classes = num_classes self.return_with_label = return_with_label + self.per_component = per_component self.dice_helper = DiceHelper( include_background=self.include_background, reduction=MetricReduction.NONE, @@ -121,6 +127,7 @@ def __init__( apply_argmax=False, ignore_empty=self.ignore_empty, num_classes=self.num_classes, + per_component=self.per_component, ) def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] @@ -175,6 +182,7 @@ def compute_dice( include_background: bool = True, ignore_empty: bool = True, num_classes: int | None = None, + per_component: bool = False, ) -> torch.Tensor: """ Computes Dice score metric for a batch of predictions. This performs the same computation as @@ -204,6 +212,7 @@ def compute_dice( apply_argmax=False, ignore_empty=ignore_empty, num_classes=num_classes, + per_component=per_component, )(y_pred=y_pred, y=y) @@ -262,6 +271,7 @@ def __init__( num_classes: int | None = None, sigmoid: bool | None = None, softmax: bool | None = None, + per_component: bool = False, ) -> None: # handling deprecated arguments if sigmoid is not None: @@ -277,6 +287,73 @@ def __init__( self.activate = activate self.ignore_empty = ignore_empty self.num_classes = num_classes + self.per_component = per_component + + def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): + """ + Voronoi assignment to connected components (CPU, single EDT) without cc3d. + Returns the ID of the nearest component for each voxel. + + Args: + labels: input label map as a numpy array, where values > 0 are considered seeds for connected components. + connectivity: 6/18/26 (3D) + sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt) + """ + + x = np.asarray(labels) + conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3) + structure = generate_binary_structure(rank=3, connectivity=conn_rank) + cc, num = sn_label(x > 0, structure=structure) + if num == 0: + return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32) + edt_input = np.ones(cc.shape, dtype=np.uint8) + edt_input[cc > 0] = 0 + indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True) + voronoi = cc[tuple(indices)] + return torch.from_numpy(voronoi) + + def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately + for each batch item and for each channel of those items. + + Args: + y_pred: input predictions with shape HW[D]. + y: ground truth with shape HW[D]. + """ + data = [] + if y_pred.ndim == y.ndim: + y_pred_idx = torch.argmax(y_pred, dim=1) + y_idx = torch.argmax(y, dim=1) + else: + y_pred_idx = y_pred + y_idx = y + if y_idx[0].sum() == 0: + if y_pred_idx.sum() == 0: + data.append(torch.tensor(1.0, device=y_idx.device)) + else: + data.append(torch.tensor(0.0, device=y_idx.device)) + else: + cc_assignment = self.compute_voronoi_regions_fast(y_idx[0]) + uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True) + nof_components = uniq.numel() + code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1) + idx = (inv << 2) | code + hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4) + _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3] + denom = 2 * tp + fp + fn + dice_scores = torch.where( + denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device) + ) + data.append(dice_scores.unsqueeze(-1)) + data = [ + torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data + ] + data = [ + torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data + ] + data = [x.reshape(-1, 1) for x in data] + return torch.stack([x.mean() for x in data]) def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ @@ -322,7 +399,13 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 - first_ch = 0 if self.include_background else 1 + if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2): + raise ValueError( + f"per_component requires 5D binary segmentation with 2 channels (background + foreground). " + f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)." + ) + + first_ch = 0 if self.include_background and not self.per_component else 1 data = [] for b in range(y_pred.shape[0]): c_list = [] @@ -330,7 +413,10 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] c_list.append(self.compute_channel(x_pred, x)) + if self.per_component: + c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] data.append(torch.stack(c_list)) + data = torch.stack(data, dim=0).contiguous() # type: ignore f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index 04c81ff9a7..bc5ce2c981 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -18,6 +18,7 @@ from parameterized import parameterized from monai.metrics import DiceHelper, DiceMetric, compute_dice +from monai.metrics.fid import FIDMetric _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background @@ -250,6 +251,31 @@ {"label_1": 0.4000, "label_2": 0.6667}, ] +TEST_CASE_16 = [ + {"per_component": True}, + { + "y": ( + lambda: ( + y := torch.zeros((5, 2, 64, 64, 64)), + y.__setitem__((0, 1, slice(20, 25), slice(20, 25), slice(20, 25)), 1), + y.__setitem__((0, 1, slice(40, 45), slice(40, 45), slice(40, 45)), 1), + y.__setitem__((0, 0), 1 - y[0, 1]), + y, + )[-1] + )(), + "y_pred": ( + lambda: ( + y_hat := torch.zeros((5, 2, 64, 64, 64)), + y_hat.__setitem__((0, 1, slice(21, 26), slice(21, 26), slice(21, 26)), 1), + y_hat.__setitem__((0, 1, slice(41, 46), slice(39, 44), slice(41, 46)), 1), + y_hat.__setitem__((0, 0), 1 - y_hat[0, 1]), + y_hat, + )[-1] + )(), + }, + [[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]], +] + class TestComputeMeanDice(unittest.TestCase): @@ -301,6 +327,18 @@ def test_nans_class(self, params, input_data, expected_value): else: np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + # CC DiceMetric tests + @parameterized.expand([TEST_CASE_16]) + def test_cc_dice_value(self, params, input_data, expected_value): + dice_metric = DiceMetric(**params) + dice_metric(**input_data) + result = dice_metric.aggregate(reduction="none") + np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + + def test_input_dimensions(self): + with self.assertRaises(ValueError): + DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) + if __name__ == "__main__": unittest.main() From ccca77abe1d768ab73cb868502cbf6de33e65351 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Fri, 13 Mar 2026 11:05:20 -0400 Subject: [PATCH 2/8] Adding per_component information to inline docstring Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/meandice.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index d61d07475c..6352cfc10d 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -99,6 +99,9 @@ class DiceMetric(CumulativeIterationMetric): If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, the index begins at "0", otherwise at "1". It can also take a list of label names. The outcome will then be returned as a dictionary. + per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be + computed for each connected component in the ground truth, and then averaged. This requires 5D binary + segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. """ @@ -200,6 +203,9 @@ def compute_dice( num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. + per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be + computed for each connected component in the ground truth, and then averaged. This requires 5D binary + segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. Returns: Dice scores per batch and per class, (shape: [batch_size, num_classes]). @@ -255,6 +261,9 @@ class DiceHelper: num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. + per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be + computed for each connected component in the ground truth, and then averaged. This requires 5D binary + segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. """ @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax") From c110e2abff1f70572f9a9e6135e815c9dfe6b7ed Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Fri, 13 Mar 2026 11:09:34 -0400 Subject: [PATCH 3/8] fixing indentation and formatting Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/meandice.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 6352cfc10d..edc6f60f04 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -99,8 +99,8 @@ class DiceMetric(CumulativeIterationMetric): If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, the index begins at "0", otherwise at "1". It can also take a list of label names. The outcome will then be returned as a dictionary. - per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be - computed for each connected component in the ground truth, and then averaged. This requires 5D binary + per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be + computed for each connected component in the ground truth, and then averaged. This requires 5D binary segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. """ @@ -203,9 +203,9 @@ def compute_dice( num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. - per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be - computed for each connected component in the ground truth, and then averaged. This requires 5D binary - segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. + per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be + computed for each connected component in the ground truth, and then averaged. This requires 5D binary + segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. Returns: Dice scores per batch and per class, (shape: [batch_size, num_classes]). @@ -261,8 +261,8 @@ class DiceHelper: num_classes: number of input channels (always including the background). When this is ``None``, ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are single-channel class indices and the number of classes is not automatically inferred from data. - per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be - computed for each connected component in the ground truth, and then averaged. This requires 5D binary + per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be + computed for each connected component in the ground truth, and then averaged. This requires 5D binary segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. """ From 41e52c1a6f267ca1d44598259a1c357fbe12d91f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:10:48 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/metrics/meandice.py | 1 - tests/metrics/test_compute_meandice.py | 1 - 2 files changed, 2 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index edc6f60f04..11f4e67a54 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -13,7 +13,6 @@ import numpy as np import torch -import torch.nn.functional as F from scipy.ndimage import distance_transform_edt, generate_binary_structure from scipy.ndimage import label as sn_label diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index bc5ce2c981..119eb255af 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -18,7 +18,6 @@ from parameterized import parameterized from monai.metrics import DiceHelper, DiceMetric, compute_dice -from monai.metrics.fid import FIDMetric _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background From 34a68174b86b87f424e75679ca91db8c124bd3f5 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Fri, 13 Mar 2026 11:57:10 -0400 Subject: [PATCH 5/8] Adding optional import for scipy and fixing issues raised by coderabbitai - docstring issues, ignore_empty bug Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/meandice.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index edc6f60f04..ea0e07d0a7 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -14,14 +14,19 @@ import numpy as np import torch import torch.nn.functional as F + from scipy.ndimage import distance_transform_edt, generate_binary_structure from scipy.ndimage import label as sn_label from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction, deprecated_arg +from monai.utils.module import optional_import from .metric import CumulativeIterationMetric +distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt") +sn_label, _ = optional_import("scipy.ndimage", name="label") + __all__ = ["DiceMetric", "compute_dice", "DiceHelper"] @@ -304,11 +309,15 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): Returns the ID of the nearest component for each voxel. Args: - labels: input label map as a numpy array, where values > 0 are considered seeds for connected components. - connectivity: 6/18/26 (3D) - sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt) - """ + labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds. + connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26. + sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances. + Returns: + torch.Tensor: Voronoi region IDs (int32) on CPU. + """ + if not has_ndimage: + raise RuntimeError("scipy.ndimage is required for per_component Dice computation.") x = np.asarray(labels) conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3) structure = generate_binary_structure(rank=3, connectivity=conn_rank) @@ -323,12 +332,14 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ - Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately - for each batch item and for each channel of those items. + Compute per-component Dice for a single batch item. Args: - y_pred: input predictions with shape HW[D]. - y: ground truth with shape HW[D]. + y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W). + y (torch.Tensor): Ground truth with shape (1, 2, D, H, W). + + Returns: + torch.Tensor: Mean Dice over connected components. """ data = [] if y_pred.ndim == y.ndim: @@ -338,7 +349,9 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred_idx = y_pred y_idx = y if y_idx[0].sum() == 0: - if y_pred_idx.sum() == 0: + if self.ignore_empty: + data.append(torch.tensor(float("nan"), device=y_idx.device)) + elif y_pred_idx.sum() == 0: data.append(torch.tensor(1.0, device=y_idx.device)) else: data.append(torch.tensor(0.0, device=y_idx.device)) From cb433a83b0d4a2b779e10aada9831773af8f6982 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Fri, 13 Mar 2026 14:38:18 -0400 Subject: [PATCH 6/8] Adding optional import for scipy and fixing issues raised by coderabbitai - docstring issues, ignore_empty bug Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/meandice.py | 33 +++++++++++++++++--------- tests/metrics/test_compute_meandice.py | 2 +- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 11f4e67a54..452cdd1f5a 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -13,14 +13,17 @@ import numpy as np import torch -from scipy.ndimage import distance_transform_edt, generate_binary_structure -from scipy.ndimage import label as sn_label from monai.metrics.utils import do_metric_reduction from monai.utils import MetricReduction, deprecated_arg +from monai.utils.module import optional_import from .metric import CumulativeIterationMetric +distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt") +generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure") +sn_label, _ = optional_import("scipy.ndimage", name="label") + __all__ = ["DiceMetric", "compute_dice", "DiceHelper"] @@ -303,11 +306,15 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): Returns the ID of the nearest component for each voxel. Args: - labels: input label map as a numpy array, where values > 0 are considered seeds for connected components. - connectivity: 6/18/26 (3D) - sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt) - """ + labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds. + connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26. + sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances. + Returns: + torch.Tensor: Voronoi region IDs (int32) on CPU. + """ + if not has_ndimage: + raise RuntimeError("scipy.ndimage is required for per_component Dice computation.") x = np.asarray(labels) conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3) structure = generate_binary_structure(rank=3, connectivity=conn_rank) @@ -322,12 +329,14 @@ def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ - Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately - for each batch item and for each channel of those items. + Compute per-component Dice for a single batch item. Args: - y_pred: input predictions with shape HW[D]. - y: ground truth with shape HW[D]. + y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W). + y (torch.Tensor): Ground truth with shape (1, 2, D, H, W). + + Returns: + torch.Tensor: Mean Dice over connected components. """ data = [] if y_pred.ndim == y.ndim: @@ -337,7 +346,9 @@ def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y_pred_idx = y_pred y_idx = y if y_idx[0].sum() == 0: - if y_pred_idx.sum() == 0: + if self.ignore_empty: + data.append(torch.tensor(float("nan"), device=y_idx.device)) + elif y_pred_idx.sum() == 0: data.append(torch.tensor(1.0, device=y_idx.device)) else: data.append(torch.tensor(0.0, device=y_idx.device)) diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index 119eb255af..0df2f15f44 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -251,7 +251,7 @@ ] TEST_CASE_16 = [ - {"per_component": True}, + {"per_component": True, "ignore_empty": False}, { "y": ( lambda: ( From ba2e0b314aa8ca411b1db988a18722a20b3f3043 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Fri, 13 Mar 2026 15:01:35 -0400 Subject: [PATCH 7/8] Adding unittest skipUnless for scipy.ndimage and resolving mypy bug Signed-off-by: Vijay Vignesh Prasad Rao --- tests/metrics/test_compute_meandice.py | 37 ++++++++++++-------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index 0df2f15f44..9f762a763d 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -18,6 +18,9 @@ from parameterized import parameterized from monai.metrics import DiceHelper, DiceMetric, compute_dice +from monai.utils.module import optional_import + +_, has_ndimage = optional_import("scipy.ndimage") _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background @@ -250,32 +253,26 @@ {"label_1": 0.4000, "label_2": 0.6667}, ] +# Testcase for per_component DiceMetric +y = torch.zeros((5, 2, 64, 64, 64)) +y_hat = torch.zeros((5, 2, 64, 64, 64)) + +y[0, 1, 20:25, 20:25, 20:25] = 1 +y[0, 1, 40:45, 40:45, 40:45] = 1 +y[0, 0] = 1 - y[0, 1] + +y_hat[0, 1, 21:26, 21:26, 21:26] = 1 +y_hat[0, 1, 41:46, 39:44, 41:46] = 1 +y_hat[0, 0] = 1 - y_hat[0, 1] + TEST_CASE_16 = [ {"per_component": True, "ignore_empty": False}, - { - "y": ( - lambda: ( - y := torch.zeros((5, 2, 64, 64, 64)), - y.__setitem__((0, 1, slice(20, 25), slice(20, 25), slice(20, 25)), 1), - y.__setitem__((0, 1, slice(40, 45), slice(40, 45), slice(40, 45)), 1), - y.__setitem__((0, 0), 1 - y[0, 1]), - y, - )[-1] - )(), - "y_pred": ( - lambda: ( - y_hat := torch.zeros((5, 2, 64, 64, 64)), - y_hat.__setitem__((0, 1, slice(21, 26), slice(21, 26), slice(21, 26)), 1), - y_hat.__setitem__((0, 1, slice(41, 46), slice(39, 44), slice(41, 46)), 1), - y_hat.__setitem__((0, 0), 1 - y_hat[0, 1]), - y_hat, - )[-1] - )(), - }, + {"y": y, "y_pred": y_hat}, [[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]], ] +@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.") class TestComputeMeanDice(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12]) From d9bfb5d24d292b497dbb93c03b07b9925b5dea5e Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Prasad Rao Date: Fri, 13 Mar 2026 16:37:08 -0400 Subject: [PATCH 8/8] Adding unittest skip only to test cc functions and resolving shape check bug Signed-off-by: Vijay Vignesh Prasad Rao --- monai/metrics/meandice.py | 11 ++++++----- tests/metrics/test_compute_meandice.py | 3 ++- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 452cdd1f5a..cdb7831a60 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -418,11 +418,12 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 - if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2): - raise ValueError( - f"per_component requires 5D binary segmentation with 2 channels (background + foreground). " - f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)." - ) + if self.per_component: + if len(y_pred.shape) != 5 or len(y.shape) != 5 or y_pred.shape[1] != 2 or y.shape[1] != 2: + raise ValueError( + "per_component requires both y_pred and y to be 5D binary segmentations " + f"with 2 channels. Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}." + ) first_ch = 0 if self.include_background and not self.per_component else 1 data = [] diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index 9f762a763d..9495113614 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -272,7 +272,6 @@ ] -@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.") class TestComputeMeanDice(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12]) @@ -325,12 +324,14 @@ def test_nans_class(self, params, input_data, expected_value): # CC DiceMetric tests @parameterized.expand([TEST_CASE_16]) + @unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.") def test_cc_dice_value(self, params, input_data, expected_value): dice_metric = DiceMetric(**params) dice_metric(**input_data) result = dice_metric.aggregate(reduction="none") np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + @unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.") def test_input_dimensions(self): with self.assertRaises(ValueError): DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))