diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py
index fedd94fb93..cdb7831a60 100644
--- a/monai/metrics/meandice.py
+++ b/monai/metrics/meandice.py
@@ -11,13 +11,19 @@
 from __future__ import annotations
 
+import numpy as np
 import torch
 
 from monai.metrics.utils import do_metric_reduction
 from monai.utils import MetricReduction, deprecated_arg
+from monai.utils.module import optional_import
 
 from .metric import CumulativeIterationMetric
 
+distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt")
+generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure")
+sn_label, _ = optional_import("scipy.ndimage", name="label")
+
 
 __all__ = ["DiceMetric", "compute_dice", "DiceHelper"]
 
 
@@ -95,6 +101,9 @@ class DiceMetric(CumulativeIterationMetric):
             If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
             the index begins at "0", otherwise at "1". It can also take a list of label names.
             The outcome will then be returned as a dictionary.
+        per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
+            computed for each connected component in the ground truth, and then averaged. This requires 5D binary
+            segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
 
     """
 
@@ -106,6 +115,7 @@ def __init__(
         ignore_empty: bool = True,
         num_classes: int | None = None,
         return_with_label: bool | list[str] = False,
+        per_component: bool = False,
     ) -> None:
         super().__init__()
         self.include_background = include_background
@@ -114,6 +124,7 @@ def __init__(
         self.ignore_empty = ignore_empty
         self.num_classes = num_classes
         self.return_with_label = return_with_label
+        self.per_component = per_component
         self.dice_helper = DiceHelper(
             include_background=self.include_background,
             reduction=MetricReduction.NONE,
@@ -121,6 +132,7 @@ def __init__(
             apply_argmax=False,
             ignore_empty=self.ignore_empty,
             num_classes=self.num_classes,
+            per_component=self.per_component,
         )
 
     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
@@ -175,6 +187,7 @@ def compute_dice(
     include_background: bool = True,
     ignore_empty: bool = True,
     num_classes: int | None = None,
+    per_component: bool = False,
 ) -> torch.Tensor:
     """
     Computes Dice score metric for a batch of predictions. This performs the same computation as
@@ -192,6 +205,9 @@ def compute_dice(
         num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
            single-channel class indices and the number of classes is not automatically inferred from data.
+        per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
+            computed for each connected component in the ground truth, and then averaged. This requires 5D binary
+            segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
 
     Returns:
         Dice scores per batch and per class, (shape: [batch_size, num_classes]).
@@ -204,6 +220,7 @@ def compute_dice(
         apply_argmax=False,
         ignore_empty=ignore_empty,
         num_classes=num_classes,
+        per_component=per_component,
     )(y_pred=y_pred, y=y)
 
 
@@ -246,6 +263,9 @@ class DiceHelper:
         num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
+        per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
+            computed for each connected component in the ground truth, and then averaged. This requires 5D binary
+            segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
     """
 
     @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
@@ -262,6 +282,7 @@ def __init__(
         num_classes: int | None = None,
         sigmoid: bool | None = None,
         softmax: bool | None = None,
+        per_component: bool = False,
     ) -> None:
         # handling deprecated arguments
         if sigmoid is not None:
@@ -277,6 +298,82 @@ def __init__(
         self.activate = activate
         self.ignore_empty = ignore_empty
         self.num_classes = num_classes
+        self.per_component = per_component
+
+    def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
+        """
+        Voronoi assignment to connected components (CPU, single EDT) without cc3d.
+        Returns the ID of the nearest component for each voxel.
+
+        Args:
+            labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
+            connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
+            sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
+
+        Returns:
+            torch.Tensor: Voronoi region IDs (int32) on CPU.
+        """
+        if not has_ndimage:
+            raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
+        # EDT runs on CPU; accept CUDA tensors by moving them across first.
+        x = np.asarray(labels.detach().cpu() if isinstance(labels, torch.Tensor) else labels)
+        conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
+        structure = generate_binary_structure(rank=3, connectivity=conn_rank)
+        cc, num = sn_label(x > 0, structure=structure)
+        if num == 0:
+            return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
+        edt_input = np.ones(cc.shape, dtype=np.uint8)
+        edt_input[cc > 0] = 0
+        indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
+        voronoi = cc[tuple(indices)]
+        return torch.from_numpy(voronoi)
+
+    def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        Compute per-component Dice for a single batch item.
+
+        Args:
+            y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
+            y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
+
+        Returns:
+            torch.Tensor: Mean Dice over connected components.
+        """
+        data = []
+        if y_pred.ndim == y.ndim:
+            y_pred_idx = torch.argmax(y_pred, dim=1)
+            y_idx = torch.argmax(y, dim=1)
+        else:
+            y_pred_idx = y_pred
+            y_idx = y
+        if y_idx[0].sum() == 0:
+            if self.ignore_empty:
+                data.append(torch.tensor(float("nan"), device=y_idx.device))
+            elif y_pred_idx.sum() == 0:
+                data.append(torch.tensor(1.0, device=y_idx.device))
+            else:
+                data.append(torch.tensor(0.0, device=y_idx.device))
+        else:
+            # move the Voronoi assignment back to the input device so the integer ops below don't mix devices
+            cc_assignment = self.compute_voronoi_regions_fast(y_idx[0]).to(y_idx.device)
+            uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
+            nof_components = uniq.numel()
+            code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
+            idx = (inv << 2) | code
+            hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
+            _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
+            denom = 2 * tp + fp + fn
+            dice_scores = torch.where(
+                denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
+            )
+            data.append(dice_scores.unsqueeze(-1))
+        # clamp non-finite scores, but keep the NaN deliberately produced by
+        # `ignore_empty` so that `do_metric_reduction` can skip empty items.
+        data = [
+            torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
+        ]
+        data = [x.reshape(-1, 1) for x in data]
+        return torch.stack([x.mean() for x in data])
 
     def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """
@@ -322,7 +419,14 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
                 y_pred = torch.sigmoid(y_pred)
             y_pred = y_pred > 0.5
 
-        first_ch = 0 if self.include_background else 1
+        if self.per_component:
+            if len(y_pred.shape) != 5 or len(y.shape) != 5 or y_pred.shape[1] != 2 or y.shape[1] != 2:
+                raise ValueError(
+                    "per_component requires both y_pred and y to be 5D binary segmentations "
+                    f"with 2 channels. Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
+                )
+
+        first_ch = 0 if self.include_background and not self.per_component else 1
         data = []
         for b in range(y_pred.shape[0]):
             c_list = []
@@ -330,7 +434,10 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
                 x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
                 x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
                 c_list.append(self.compute_channel(x_pred, x))
+            if self.per_component:
+                c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
             data.append(torch.stack(c_list))
+
         data = torch.stack(data, dim=0).contiguous()  # type: ignore
 
         f, not_nans = do_metric_reduction(data, self.reduction)  # type: ignore
diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py
index 04c81ff9a7..9495113614 100644
--- a/tests/metrics/test_compute_meandice.py
+++ b/tests/metrics/test_compute_meandice.py
@@ -18,6 +18,9 @@
 from parameterized import parameterized
 
 from monai.metrics import DiceHelper, DiceMetric, compute_dice
+from monai.utils.module import optional_import
+
+_, has_ndimage = optional_import("scipy.ndimage")
 
 _device = "cuda:0" if torch.cuda.is_available() else "cpu"
 # keep background
@@ -250,6 +253,24 @@
     {"label_1": 0.4000, "label_2": 0.6667},
 ]
 
+# Testcase for per_component DiceMetric
+y = torch.zeros((5, 2, 64, 64, 64))
+y_hat = torch.zeros((5, 2, 64, 64, 64))
+
+y[0, 1, 20:25, 20:25, 20:25] = 1
+y[0, 1, 40:45, 40:45, 40:45] = 1
+y[0, 0] = 1 - y[0, 1]
+
+y_hat[0, 1, 21:26, 21:26, 21:26] = 1
+y_hat[0, 1, 41:46, 39:44, 41:46] = 1
+y_hat[0, 0] = 1 - y_hat[0, 1]
+
+TEST_CASE_16 = [
+    {"per_component": True, "ignore_empty": False},
+    {"y": y, "y_pred": y_hat},
+    [[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]],
+]
+
 
 class TestComputeMeanDice(unittest.TestCase):
 
@@ -301,6 +322,20 @@ def test_nans_class(self, params, input_data, expected_value):
         else:
             np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
 
+    # CC DiceMetric tests
+    @parameterized.expand([TEST_CASE_16])
+    @unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
+    def test_cc_dice_value(self, params, input_data, expected_value):
+        dice_metric = DiceMetric(**params)
+        dice_metric(**input_data)
+        result = dice_metric.aggregate(reduction="none")
+        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+    @unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
+    def test_input_dimensions(self):
+        with self.assertRaises(ValueError):
+            DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
+
 
 if __name__ == "__main__":
     unittest.main()