
Add parameter to DiceMetric and DiceHelper classes #8774

Draft
VijayVignesh1 wants to merge 6 commits into Project-MONAI:dev from VijayVignesh1:8733-per-component-dice-metric

Conversation

@VijayVignesh1

Fixes #8733

Description

This PR adds support for connected component-based Dice metric calculation to the existing DiceMetric and DiceHelper classes.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
@coderabbitai

coderabbitai bot commented Mar 13, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 114f10c0-c613-4f87-b8cf-77651cdf65a1

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

The changes add per-component Dice evaluation capability to MONAI's DiceMetric. A new per_component parameter has been introduced across DiceMetric, DiceHelper, and the compute_dice function. When enabled, the implementation decomposes ground truth into connected components using Voronoi-based labeling and computes Dice scores per component. Two new methods (compute_voronoi_regions_fast and compute_cc_dice) enable component-wise calculations. Input validation requires 5D binary segmentation tensors with exactly 2 channels. Corresponding test cases validate the new functionality and input constraints.
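The idea in the walkthrough can be illustrated with a small, self-contained sketch (a hypothetical toy, not the PR's actual code): the ground truth is split into connected components, every voxel is assigned to its nearest component (a Voronoi partition of the volume), and Dice is then computed inside each region separately.

```python
import numpy as np

def per_component_dice(y, y_pred, regions):
    """Compute Dice within each Voronoi region of the ground truth.

    y, y_pred: binary arrays; regions: int array giving, for every voxel,
    the ID (>= 1) of the nearest ground-truth component.
    """
    scores = {}
    for cid in np.unique(regions):
        m = regions == cid                     # voxels belonging to this region
        tp = np.sum(y[m] & y_pred[m])          # true positives inside the region
        denom = y[m].sum() + y_pred[m].sum()
        scores[int(cid)] = 2.0 * tp / denom if denom else 1.0
    return scores

# 1D toy example: two GT components; the prediction hits only the first one.
y      = np.array([1, 1, 0, 0, 0, 1, 1, 0], dtype=bool)
y_pred = np.array([1, 1, 1, 0, 0, 0, 0, 0], dtype=bool)
# Hand-built Voronoi partition: left half nearest component 1, right half 2.
regions = np.array([1, 1, 1, 1, 2, 2, 2, 2])

print(per_component_dice(y, y_pred, regions))  # component 2 is fully missed
```

Unlike a plain per-channel Dice, the completely missed second component drags its region's score to 0 instead of being averaged away.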

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 54.55%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Title check (❓ Inconclusive): The title is vague and generic, failing to highlight the key feature of per-component Dice calculation. Resolution: use a more specific title such as "Add per_component parameter for connected-component-based Dice metric calculation".

✅ Passed checks (3 passed)

  • Description check (✅ Passed): The description covers the main changes and includes all required template sections, with checkmarks for testing and documentation.
  • Linked Issues check (✅ Passed): The changes implement the per_component parameter for the Dice metric with Voronoi-based component labeling, validation logic, and per-component aggregation, as required by #8733.
  • Out of Scope Changes check (✅ Passed): All changes focus on the DiceMetric/DiceHelper per_component implementation and tests; no unrelated modifications detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (3)
monai/metrics/meandice.py (1)

418-426: Wasted computation when per_component=True.

Lines 420-423 compute per-channel Dice, then lines 424-425 discard those results and overwrite c_list. Move the branch earlier.

Proposed fix
         for b in range(y_pred.shape[0]):
-            c_list = []
-            for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
-                x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
-                x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
-                c_list.append(self.compute_channel(x_pred, x))
             if self.per_component:
                 c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
+            else:
+                c_list = []
+                for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
+                    x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
+                    x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
+                    c_list.append(self.compute_channel(x_pred, x))
             data.append(torch.stack(c_list))
tests/metrics/test_compute_meandice.py (2)

253-276: Test data construction is hard to follow; expected value undocumented.

The lambda-walrus pattern obscures setup. Consider a helper function. Also document how 0.5120 was derived for maintainability.


337-339: Shape mismatch may obscure test intent.

Both tensors are 4D (not 5D) and have 3 channels (not 2). The spatial mismatch (144 vs 145) is irrelevant to the validation. Use matching shapes to clarify:

-            DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
+            DiceMetric(per_component=True)(torch.ones([3, 3, 64, 64]), torch.ones([3, 3, 64, 64]))

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f8c5c98e-0bb3-413d-9471-3bef41a45cfa

📥 Commits

Reviewing files that changed from the base of the PR and between daaedaa and 41e52c1.

📒 Files selected for processing (2)
  • monai/metrics/meandice.py
  • tests/metrics/test_compute_meandice.py

Comment on lines +14 to +17
import numpy as np
import torch
from scipy.ndimage import distance_transform_edt, generate_binary_structure
from scipy.ndimage import label as sn_label

⚠️ Potential issue | 🔴 Critical

SciPy is optional but imported unconditionally—breaks CI.

Per setup.cfg, scipy is in [options.extras_require] not install_requires. Pipeline failure confirms: ModuleNotFoundError: No module named 'scipy'.

Use MONAI's optional_import pattern and guard usage at runtime.

Proposed fix
-import numpy as np
 import torch
-from scipy.ndimage import distance_transform_edt, generate_binary_structure
-from scipy.ndimage import label as sn_label
+import numpy as np
+
+from monai.utils.module import optional_import
+
+distance_transform_edt, has_scipy = optional_import("scipy.ndimage", name="distance_transform_edt")
+generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure")
+sn_label, _ = optional_import("scipy.ndimage", name="label")

Then guard in compute_voronoi_regions_fast:

if not has_scipy:
    raise RuntimeError("scipy is required for per_component Dice computation.")
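The same guard can be reproduced with only the standard library, for code that does not use MONAI's optional_import helper (a hypothetical sketch, not the PR's code): probe for the module once, import lazily, and fail at call time with a clear message rather than at import time.

```python
import importlib.util

# Detect scipy without importing it eagerly; failure stays deferred.
has_scipy = importlib.util.find_spec("scipy") is not None

if has_scipy:
    from scipy.ndimage import distance_transform_edt
else:
    distance_transform_edt = None  # placeholder; guarded before use

def require_scipy() -> None:
    # Raise a clear error at the call site instead of breaking `import monai`.
    if not has_scipy:
        raise RuntimeError("scipy is required for per_component Dice computation.")
```

Call require_scipy() at the top of any function that touches the scipy symbols; module import then succeeds even on minimal installs.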
🧰 Tools
🪛 GitHub Actions: premerge-min

[error] 16-16: ModuleNotFoundError: No module named 'scipy' while importing MONAI; ensure SciPy is installed in the CI environment.


Comment on lines +300 to +321
    def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
        """
        Voronoi assignment to connected components (CPU, single EDT) without cc3d.
        Returns the ID of the nearest component for each voxel.

        Args:
            labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
            connectivity: 6/18/26 (3D)
            sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
        """

        x = np.asarray(labels)
        conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
        structure = generate_binary_structure(rank=3, connectivity=conn_rank)
        cc, num = sn_label(x > 0, structure=structure)
        if num == 0:
            return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
        edt_input = np.ones(cc.shape, dtype=np.uint8)
        edt_input[cc > 0] = 0
        indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
        voronoi = cc[tuple(indices)]
        return torch.from_numpy(voronoi)

⚠️ Potential issue | 🟡 Minor

Docstring incomplete; return always on CPU.

  1. Missing Returns: section and type annotations per coding guidelines.
  2. Line 321 returns CPU tensor via torch.from_numpy() regardless of input device.
Docstring fix
     def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
         """
         Voronoi assignment to connected components (CPU, single EDT) without cc3d.
-        Returns the ID of the nearest component for each voxel.
 
         Args:
-            labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
-            connectivity: 6/18/26 (3D)
-            sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
+            labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
+            connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
+            sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
+
+        Returns:
+            torch.Tensor: Voronoi region IDs (int32) on CPU.
         """

Comment on lines +323 to +364
    def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
        for each batch item and for each channel of those items.

        Args:
            y_pred: input predictions with shape HW[D].
            y: ground truth with shape HW[D].
        """
        data = []
        if y_pred.ndim == y.ndim:
            y_pred_idx = torch.argmax(y_pred, dim=1)
            y_idx = torch.argmax(y, dim=1)
        else:
            y_pred_idx = y_pred
            y_idx = y
        if y_idx[0].sum() == 0:
            if y_pred_idx.sum() == 0:
                data.append(torch.tensor(1.0, device=y_idx.device))
            else:
                data.append(torch.tensor(0.0, device=y_idx.device))
        else:
            cc_assignment = self.compute_voronoi_regions_fast(y_idx[0])
            uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
            nof_components = uniq.numel()
            code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
            idx = (inv << 2) | code
            hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
            _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
            denom = 2 * tp + fp + fn
            dice_scores = torch.where(
                denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
            )
            data.append(dice_scores.unsqueeze(-1))
            data = [
                torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
            ]
            data = [
                torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
            ]
        data = [x.reshape(-1, 1) for x in data]
        return torch.stack([x.mean() for x in data])
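The packed-histogram trick in this method can be sanity-checked in isolation with a NumPy rendition (a toy sketch, not the PR's tensor code): each voxel contributes a 2-bit code (gt << 1 | pred), prefixed by its region index, so a single bincount produces a per-region confusion table.

```python
import numpy as np

gt   = np.array([1, 1, 0, 0, 1, 0])
pred = np.array([1, 0, 1, 0, 1, 0])
inv  = np.array([0, 0, 0, 1, 1, 1])    # region index of each voxel

code = (gt << 1) | pred                 # 0=TN, 1=FP, 2=FN, 3=TP
idx = (inv << 2) | code                 # 4 * region + code
hist = np.bincount(idx, minlength=2 * 4).reshape(-1, 4)
tn, fp, fn, tp = hist.T                 # per-region confusion counts

dice = 2 * tp / (2 * tp + fp + fn)      # [0.5, 1.0] for this toy input
```

Region 0 has one TP, one FP, one FN (Dice 2/4 = 0.5); region 1 is a perfect match (Dice 1.0). One vectorized bincount replaces a Python loop over components.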

⚠️ Potential issue | 🟠 Major

Docstring states wrong shape; ignore_empty flag ignored.

  1. Docstring says HW[D] but method receives (1, C, D, H, W) from line 425.
  2. self.ignore_empty is not consulted—empty GT always returns 1.0 or 0.0 (lines 340-343), contradicting class behavior.
  3. Lines 357-362 (inf/nan replacement) only run in else branch; unreachable for empty GT case.
Fix docstring and respect ignore_empty
     def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """
-        Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
-        for each batch item and for each channel of those items.
+        Compute per-component Dice for a single batch item.
 
         Args:
-            y_pred: input predictions with shape HW[D].
-            y: ground truth with shape HW[D].
+            y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
+            y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
+
+        Returns:
+            torch.Tensor: Mean Dice over connected components.
         """
         data = []
         if y_pred.ndim == y.ndim:
             y_pred_idx = torch.argmax(y_pred, dim=1)
             y_idx = torch.argmax(y, dim=1)
         else:
             y_pred_idx = y_pred
             y_idx = y
         if y_idx[0].sum() == 0:
-            if y_pred_idx.sum() == 0:
-                data.append(torch.tensor(1.0, device=y_idx.device))
-            else:
-                data.append(torch.tensor(0.0, device=y_idx.device))
+            if self.ignore_empty:
+                data.append(torch.tensor(float("nan"), device=y_idx.device))
+            elif y_pred_idx.sum() == 0:
+                data.append(torch.tensor(1.0, device=y_idx.device))
+            else:
+                data.append(torch.tensor(0.0, device=y_idx.device))
         else:
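The empty-ground-truth semantics proposed above (NaN as an "ignore" sentinel, assuming the reviewer's suggested behavior) can be exercised in isolation with a small hypothetical helper:

```python
import math

def empty_gt_score(pred_sum: int, ignore_empty: bool) -> float:
    """Score for a sample whose ground truth is empty.

    NaN marks "no score", letting a nanmean-style aggregation skip the
    sample; otherwise an empty prediction counts as a perfect match (1.0)
    and any false positive as a total miss (0.0).
    """
    if ignore_empty:
        return float("nan")
    return 1.0 if pred_sum == 0 else 0.0

assert math.isnan(empty_gt_score(0, ignore_empty=True))
assert empty_gt_score(0, ignore_empty=False) == 1.0
assert empty_gt_score(7, ignore_empty=False) == 0.0
```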

f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)."
)

first_ch = 0 if self.include_background and not self.per_component else 1

⚠️ Potential issue | 🟡 Minor

include_background silently ignored when per_component=True.

Line forces first_ch=1 regardless of include_background. Document this or raise a warning if both flags conflict.


@VijayVignesh1 VijayVignesh1 marked this pull request as draft March 13, 2026 15:40
…itai - docstring issues, ignore_empty bug

Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>


Development

Successfully merging this pull request may close these issues.

Feature Request: Evaluation of Semantic Segmentation Metrics on a per-component basis

1 participant