Skip to content
107 changes: 106 additions & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@

from __future__ import annotations

import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
from monai.utils.module import optional_import

from .metric import CumulativeIterationMetric

distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt")
generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure")
sn_label, _ = optional_import("scipy.ndimage", name="label")

__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]


Expand Down Expand Up @@ -95,6 +101,9 @@ class DiceMetric(CumulativeIterationMetric):
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

"""

Expand All @@ -106,6 +115,7 @@ def __init__(
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
Expand All @@ -114,13 +124,15 @@ def __init__(
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.per_component = per_component
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
apply_argmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
per_component=self.per_component,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -175,6 +187,7 @@ def compute_dice(
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
Computes Dice score metric for a batch of predictions. This performs the same computation as
Expand All @@ -192,6 +205,9 @@ def compute_dice(
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -204,6 +220,7 @@ def compute_dice(
apply_argmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
per_component=per_component,
)(y_pred=y_pred, y=y)


Expand Down Expand Up @@ -246,6 +263,9 @@ class DiceHelper:
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
"""

@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
Expand All @@ -262,6 +282,7 @@ def __init__(
num_classes: int | None = None,
sigmoid: bool | None = None,
softmax: bool | None = None,
per_component: bool = False,
) -> None:
# handling deprecated arguments
if sigmoid is not None:
Expand All @@ -277,6 +298,81 @@ def __init__(
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.per_component = per_component

def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
    """
    Assign every voxel to its nearest connected component (Voronoi partition)
    using a single CPU Euclidean distance transform (no cc3d dependency).

    Args:
        labels (np.ndarray | torch.Tensor): label map where values > 0 are
            treated as seeds for the connected components. CUDA tensors are
            moved to the CPU for the scipy-based computation.
        connectivity (int): 6, 18, or 26 for 3D connectivity; any other value
            falls back to 26. Defaults to 26.
        sampling (tuple[float, ...] | None): voxel spacing for anisotropic
            distances, forwarded to ``scipy.ndimage.distance_transform_edt``.

    Returns:
        torch.Tensor: Voronoi region IDs, returned on the same device as the
            input when ``labels`` is a ``torch.Tensor`` (CPU otherwise).

    Raises:
        RuntimeError: when ``scipy.ndimage`` is not available.
    """
    if not has_ndimage:
        raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
    # remember the caller's device so GPU inputs round-trip correctly;
    # np.asarray would raise on a CUDA tensor, so detach/cpu explicitly
    device = labels.device if isinstance(labels, torch.Tensor) else None
    x = labels.detach().cpu().numpy() if isinstance(labels, torch.Tensor) else np.asarray(labels)
    conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
    structure = generate_binary_structure(rank=3, connectivity=conn_rank)
    cc, num = sn_label(x > 0, structure=structure)
    if num == 0:
        # no seeds at all: every voxel belongs to the "no component" region 0
        out = torch.zeros(x.shape, dtype=torch.int32)
        return out if device is None else out.to(device)
    # EDT over the background with return_indices gives, per voxel, the index
    # of the nearest seed voxel; indexing cc with it yields the Voronoi ID
    edt_input = np.ones(cc.shape, dtype=np.uint8)
    edt_input[cc > 0] = 0
    indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
    voronoi = cc[tuple(indices)]
    out = torch.from_numpy(voronoi)
    return out if device is None else out.to(device)
Comment on lines +303 to +328
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Docstring incomplete; return always on CPU.

  1. Missing Returns: section and type annotations per coding guidelines.
  2. Line 321 returns CPU tensor via torch.from_numpy() regardless of input device.
Docstring fix
     def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
         """
         Voronoi assignment to connected components (CPU, single EDT) without cc3d.
-        Returns the ID of the nearest component for each voxel.
 
         Args:
-            labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
-            connectivity: 6/18/26 (3D)
-            sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
+            labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
+            connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
+            sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
+
+        Returns:
+            torch.Tensor: Voronoi region IDs (int32) on CPU.
         """
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
"""
Voronoi assignment to connected components (CPU, single EDT) without cc3d.
Returns the ID of the nearest component for each voxel.
Args:
labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
connectivity: 6/18/26 (3D)
sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
"""
x = np.asarray(labels)
conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
structure = generate_binary_structure(rank=3, connectivity=conn_rank)
cc, num = sn_label(x > 0, structure=structure)
if num == 0:
return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
edt_input = np.ones(cc.shape, dtype=np.uint8)
edt_input[cc > 0] = 0
indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
voronoi = cc[tuple(indices)]
return torch.from_numpy(voronoi)
def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
"""
Voronoi assignment to connected components (CPU, single EDT) without cc3d.
Args:
labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
Returns:
torch.Tensor: Voronoi region IDs (int32) on CPU.
"""
x = np.asarray(labels)
conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
structure = generate_binary_structure(rank=3, connectivity=conn_rank)
cc, num = sn_label(x > 0, structure=structure)
if num == 0:
return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
edt_input = np.ones(cc.shape, dtype=np.uint8)
edt_input[cc > 0] = 0
indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
voronoi = cc[tuple(indices)]
return torch.from_numpy(voronoi)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 300 - 321, The
compute_voronoi_regions_fast function's docstring lacks a Returns section and
the function always returns a CPU tensor (torch.from_numpy) even if the original
input was a CUDA tensor; update the docstring to include a Returns: description
and type (torch.Tensor on same device as input) and change the implementation to
preserve input type/device: accept numpy array or torch.Tensor for labels,
record the original device and dtype (if torch.Tensor), convert input to CPU
numpy for EDT processing, then convert the resulting voronoi numpy array back to
a torch.Tensor and move it to the original device and appropriate dtype before
returning; reference compute_voronoi_regions_fast, labels, edt_input, indices,
and voronoi when locating where to apply these changes.


def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute per-connected-component Dice for a single batch item.

    The ground truth is partitioned into Voronoi regions around its connected
    components; the Dice score is computed inside each region and the scores
    are averaged.

    Args:
        y_pred (torch.Tensor): predictions with shape (1, 2, D, H, W)
            (one-hot, background + foreground); when ``y_pred.ndim != y.ndim``
            the inputs are treated as class-index maps instead.
        y (torch.Tensor): ground truth with the same layout as ``y_pred``.

    Returns:
        torch.Tensor: mean Dice over connected components, shape (1,). NaN is
            returned for an empty ground truth when ``ignore_empty`` is True,
            so that ``do_metric_reduction`` can skip it downstream.
    """
    data = []
    if y_pred.ndim == y.ndim:
        y_pred_idx = torch.argmax(y_pred, dim=1)
        y_idx = torch.argmax(y, dim=1)
    else:
        y_pred_idx = y_pred
        y_idx = y
    if y_idx[0].sum() == 0:
        # empty ground truth: follow the class-level ``ignore_empty`` contract
        # (NaN so the reduction ignores it; otherwise 1.0 iff prediction is
        # also empty). This NaN must NOT be sanitized away below.
        if self.ignore_empty:
            data.append(torch.tensor(float("nan"), device=y_idx.device))
        elif y_pred_idx.sum() == 0:
            data.append(torch.tensor(1.0, device=y_idx.device))
        else:
            data.append(torch.tensor(0.0, device=y_idx.device))
    else:
        # scipy-based Voronoi assignment runs on CPU only; round-trip
        # explicitly so CUDA inputs work and all tensors share one device
        cc_assignment = self.compute_voronoi_regions_fast(y_idx[0].cpu()).to(y_idx.device)
        uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
        nof_components = uniq.numel()
        # 2-bit confusion code per voxel: (gt << 1) | pred -> {TN, FP, FN, TP}
        code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
        # interleave component id and confusion code into one histogram index
        idx = (inv << 2) | code
        hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
        _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
        denom = 2 * tp + fp + fn
        dice_scores = torch.where(
            denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
        )
        # sanitize only the computed per-component scores; the deliberate NaN
        # appended in the empty-GT branch above must survive for the reduction
        dice_scores = torch.nan_to_num(dice_scores, nan=0.0, posinf=0.0, neginf=0.0)
        data.append(dice_scores.unsqueeze(-1))
    data = [x.reshape(-1, 1) for x in data]
    return torch.stack([x.mean() for x in data])
Comment on lines +330 to +375
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Docstring states wrong shape; ignore_empty flag ignored.

  1. Docstring says HW[D] but method receives (1, C, D, H, W) from line 425.
  2. self.ignore_empty is not consulted—empty GT always returns 1.0 or 0.0 (lines 340-343), contradicting class behavior.
  3. Lines 357-362 (inf/nan replacement) only run in else branch; unreachable for empty GT case.
Fix docstring and respect ignore_empty
     def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """
-        Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
-        for each batch item and for each channel of those items.
+        Compute per-component Dice for a single batch item.
 
         Args:
-            y_pred: input predictions with shape HW[D].
-            y: ground truth with shape HW[D].
+            y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
+            y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
+
+        Returns:
+            torch.Tensor: Mean Dice over connected components.
         """
         data = []
         if y_pred.ndim == y.ndim:
             y_pred_idx = torch.argmax(y_pred, dim=1)
             y_idx = torch.argmax(y, dim=1)
         else:
             y_pred_idx = y_pred
             y_idx = y
         if y_idx[0].sum() == 0:
-            if y_pred_idx.sum() == 0:
-                data.append(torch.tensor(1.0, device=y_idx.device))
-            else:
-                data.append(torch.tensor(0.0, device=y_idx.device))
+            if self.ignore_empty:
+                data.append(torch.tensor(float("nan"), device=y_idx.device))
+            elif y_pred_idx.sum() == 0:
+                data.append(torch.tensor(1.0, device=y_idx.device))
+            else:
+                data.append(torch.tensor(0.0, device=y_idx.device))
         else:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 323 - 364, The compute_cc_dice
method's docstring and empty-ground-truth handling are incorrect: update the
docstring for compute_cc_dice to state the actual expected input shapes (e.g.,
tensors that may include batch and channel dims such as (1, C, D, H, W) or
per-channel/per-item spatial tensors) and then change the empty-GT branch (the
y_idx[0].sum() == 0 case) to consult self.ignore_empty (return
torch.tensor(0.0/1.0 or skip/ignore according to class semantics) instead of
always appending 1.0/0.0), and move the inf/nan replacement logic (the two
torch.where lines that sanitize values) out of the else block so they run for
both empty and non-empty cases; refer to symbols compute_cc_dice, y_idx,
y_pred_idx, self.ignore_empty, cc_assignment, uniq/inv/hist/dice_scores to
locate and update the logic and docstring.


def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -322,15 +418,24 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5

first_ch = 0 if self.include_background else 1
if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2):
raise ValueError(
f"per_component requires 5D binary segmentation with 2 channels (background + foreground). "
f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)."
)

first_ch = 0 if self.include_background and not self.per_component else 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

include_background silently ignored when per_component=True.

Line forces first_ch=1 regardless of include_background. Document this or raise a warning if both flags conflict.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` at line 416, The code currently sets first_ch
based on a combined condition which silently ignores include_background when
per_component is True; update the logic in the MeanDice/meandice implementation
to detect the conflicting flags (self.per_component True and
self.include_background False) and emit a clear warning (e.g., warnings.warn or
using the module logger) that include_background will be ignored when
per_component is enabled, then keep the existing behavior for first_ch (set
first_ch=1) to preserve compatibility; reference the attributes
self.include_background, self.per_component and the local variable first_ch so
reviewers can locate and adjust the check and add the warning.

data = []
for b in range(y_pred.shape[0]):
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
if self.per_component:
c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
data.append(torch.stack(c_list))

data = torch.stack(data, dim=0).contiguous() # type: ignore

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
Expand Down
34 changes: 34 additions & 0 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from parameterized import parameterized

from monai.metrics import DiceHelper, DiceMetric, compute_dice
from monai.utils.module import optional_import

_, has_ndimage = optional_import("scipy.ndimage")

_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# keep background
Expand Down Expand Up @@ -250,7 +253,26 @@
{"label_1": 0.4000, "label_2": 0.6667},
]

# Testcase for per_component DiceMetric
# Batch of 5 one-hot (background + foreground) 5D volumes; only item 0 has
# any foreground. Channel 0 is kept as the complement of channel 1.
y = torch.zeros((5, 2, 64, 64, 64))
y_hat = torch.zeros((5, 2, 64, 64, 64))

# ground truth for item 0: two disjoint 5x5x5 foreground cubes
y[0, 1, 20:25, 20:25, 20:25] = 1
y[0, 1, 40:45, 40:45, 40:45] = 1
y[0, 0] = 1 - y[0, 1]

# prediction: each cube shifted by one voxel, so each pair overlaps in a
# 4x4x4 region -> per-cube Dice = 2*64 / (125+125) = 0.512
y_hat[0, 1, 21:26, 21:26, 21:26] = 1
y_hat[0, 1, 41:46, 39:44, 41:46] = 1
y_hat[0, 0] = 1 - y_hat[0, 1]

# expected: ~0.512 for item 0 (mean over its connected components); items 1-4
# have empty GT and empty prediction, which with ignore_empty=False scores 1.0
TEST_CASE_16 = [
    {"per_component": True, "ignore_empty": False},
    {"y": y, "y_pred": y_hat},
    [[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]],
]


@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
class TestComputeMeanDice(unittest.TestCase):

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
Expand Down Expand Up @@ -301,6 +323,18 @@ def test_nans_class(self, params, input_data, expected_value):
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# Tests for the connected-component (per_component) DiceMetric
@parameterized.expand([TEST_CASE_16])
def test_cc_dice_value(self, params, input_data, expected_value):
    """Per-component Dice values should match the precomputed expectation."""
    metric = DiceMetric(**params)
    metric(**input_data)
    aggregated = metric.aggregate(reduction="none")
    np.testing.assert_allclose(aggregated.cpu().numpy(), expected_value, atol=1e-4)

def test_input_dimensions(self):
    """per_component must reject inputs that are not 5D with 2 channels."""
    metric = DiceMetric(per_component=True)
    y_pred = torch.ones([3, 3, 144, 144])
    y = torch.ones([3, 3, 145, 145])
    self.assertRaises(ValueError, metric, y_pred, y)


if __name__ == "__main__":
unittest.main()