Add parameter to DiceMetric and DiceHelper classes #8774
VijayVignesh1 wants to merge 6 commits into Project-MONAI:dev
Conversation
Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
📝 Walkthrough

The changes add per-component Dice evaluation capability to MONAI's DiceMetric.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks: 3 passed, 2 failed (1 warning, 1 inconclusive).
Actionable comments posted: 4
🧹 Nitpick comments (3)
monai/metrics/meandice.py (1)
418-426: Wasted computation when per_component=True.

Lines 420-423 compute channel Dice, then lines 424-425 discard it and overwrite c_list. Move the branch earlier.

Proposed fix:

```diff
 for b in range(y_pred.shape[0]):
-    c_list = []
-    for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
-        x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
-        x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
-        c_list.append(self.compute_channel(x_pred, x))
     if self.per_component:
         c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
+    else:
+        c_list = []
+        for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
+            x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
+            x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
+            c_list.append(self.compute_channel(x_pred, x))
     data.append(torch.stack(c_list))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 418 - 426, the loop is doing wasted work: it always computes per-channel Dice via compute_channel for each c and only when self.per_component is True it discards those results and replaces c_list with a compute_cc_dice call. Change the logic inside the for b in range(...) loop to check self.per_component before computing channels; if self.per_component is True, directly set c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] and skip the per-channel compute_channel loop and related x_pred/x extraction, otherwise run the existing per-channel path that builds c_list with compute_channel as before. Ensure references to y_pred, y, compute_channel, compute_cc_dice, c_list and per_component are used so the branch correctly short-circuits the expensive channel computations.

tests/metrics/test_compute_meandice.py (2)
253-276: Test data construction is hard to follow; expected value undocumented.

The lambda-walrus pattern obscures setup. Consider a helper function. Also document how 0.5120 was derived for maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/metrics/test_compute_meandice.py` around lines 253 - 276, TEST_CASE_16 uses a lambda-walrus pattern (variables y and y_pred inside TEST_CASE_16) that makes the test data setup hard to read and omits explanation of the expected 0.5120 value; extract the tensor construction into a small descriptive helper (e.g., build_test_case_16_tensors or make_meandice_case_16) and replace the inline lambdas with calls to that helper, and add a short comment next to the expected value explaining how 0.5120 was computed (e.g., describe overlapping voxel counts and Dice formula for the two shifted cubes) so the test is readable and the expected number is documented.
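As background for documenting such an expected value: Dice is 2|A ∩ B| / (|A| + |B|), so a short comment can derive it from voxel counts. A NumPy sketch with illustrative geometry (two 10³ cubes shifted by 4 voxels; this is not the actual TEST_CASE_16 data, and 0.5120 is not reproduced here):

```python
import numpy as np

def dice(a: np.ndarray, b: np.ndarray) -> float:
    """Dice = 2|A ∩ B| / (|A| + |B|) for binary masks."""
    inter = np.logical_and(a, b).sum()
    return 2.0 * inter / (a.sum() + b.sum())

a = np.zeros((20, 20, 20), dtype=bool)
b = np.zeros((20, 20, 20), dtype=bool)
a[0:10, 0:10, 0:10] = True   # 10x10x10 cube, |A| = 1000
b[4:14, 0:10, 0:10] = True   # same cube shifted by 4 voxels, |B| = 1000

# overlap = 6 * 10 * 10 = 600, so dice = 2 * 600 / 2000 = 0.6
print(dice(a, b))  # → 0.6
```

A one-line comment of this form next to the expected 0.5120 would make the test self-explanatory.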
337-339: Shape mismatch may obscure test intent.

Both tensors are 4D (not 5D) and have 3 channels (not 2). The spatial mismatch (144 vs 145) is irrelevant to the validation. Use matching shapes to clarify:

```diff
-DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
+DiceMetric(per_component=True)(torch.ones([3, 3, 64, 64]), torch.ones([3, 3, 64, 64]))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/metrics/test_compute_meandice.py` around lines 337 - 339, The test currently uses two 4D tensors with mismatched spatial sizes and 3 channels, which obscures the intent to validate dimensionality; update test_input_dimensions so both tensors have identical shapes but still 4D to trigger the ValueError (e.g., use torch.ones([3, 2, 144, 144]) for both), ensuring the failure comes from incorrect dimensionality for DiceMetric rather than a spatial-size mismatch or wrong channel count.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f8c5c98e-0bb3-413d-9471-3bef41a45cfa
📒 Files selected for processing (2)
monai/metrics/meandice.py
tests/metrics/test_compute_meandice.py
```python
import numpy as np
import torch
from scipy.ndimage import distance_transform_edt, generate_binary_structure
from scipy.ndimage import label as sn_label
```
SciPy is optional but imported unconditionally; this breaks CI.
Per setup.cfg, scipy is in [options.extras_require] not install_requires. Pipeline failure confirms: ModuleNotFoundError: No module named 'scipy'.
Use MONAI's optional_import pattern and guard usage at runtime.
Proposed fix
```diff
-import numpy as np
 import torch
-from scipy.ndimage import distance_transform_edt, generate_binary_structure
-from scipy.ndimage import label as sn_label
+import numpy as np
+
+from monai.utils.module import optional_import
+
+distance_transform_edt, has_scipy = optional_import("scipy.ndimage", name="distance_transform_edt")
+generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure")
+sn_label, _ = optional_import("scipy.ndimage", name="label")
```

Then guard in compute_voronoi_regions_fast:

```python
if not has_scipy:
    raise RuntimeError("scipy is required for per_component Dice computation.")
```

🧰 Tools
🪛 GitHub Actions: premerge-min
[error] 16-16: ModuleNotFoundError: No module named 'scipy' while importing MONAI; ensure SciPy is installed in the CI environment.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/metrics/meandice.py` around lines 14 - 17, The module imports SciPy
unconditionally causing CI failures when SciPy is not installed; change the
top-level imports to use MONAI's optional_import pattern to import
distance_transform_edt, generate_binary_structure and label (sn_label) and
expose a has_scipy flag, then in compute_voronoi_regions_fast check has_scipy
and raise a clear RuntimeError if False; update references to
sn_label/distance_transform_edt/generate_binary_structure in the file to use the
optionally imported symbols so runtime usage is guarded.
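The guard pattern being suggested looks roughly like the following. This is a stdlib sketch of the idea, not MONAI's actual `optional_import` implementation (which also handles version constraints and deferred error messages):

```python
import importlib

def optional_import(module: str, name: str = ""):
    """Try to import `module` (or `module.name`); return (object, success_flag)."""
    try:
        mod = importlib.import_module(module)
        return (getattr(mod, name) if name else mod), True
    except (ImportError, AttributeError):
        return None, False

# mirror the proposed top-level import
distance_transform_edt, has_scipy = optional_import("scipy.ndimage", name="distance_transform_edt")

def compute_voronoi_regions_fast(labels):
    # fail loudly at call time instead of at import time
    if not has_scipy:
        raise RuntimeError("scipy is required for per_component Dice computation.")
    ...
```

With this shape, importing the module always succeeds; only the per-component code path raises when SciPy is missing.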
```python
def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
    """
    Voronoi assignment to connected components (CPU, single EDT) without cc3d.
    Returns the ID of the nearest component for each voxel.
    Args:
        labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
        connectivity: 6/18/26 (3D)
        sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
    """
    x = np.asarray(labels)
    conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
    structure = generate_binary_structure(rank=3, connectivity=conn_rank)
    cc, num = sn_label(x > 0, structure=structure)
    if num == 0:
        return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
    edt_input = np.ones(cc.shape, dtype=np.uint8)
    edt_input[cc > 0] = 0
    indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
    voronoi = cc[tuple(indices)]
    return torch.from_numpy(voronoi)
```
Docstring incomplete; return always on CPU.

- Missing Returns: section and type annotations per coding guidelines.
- Line 321 returns a CPU tensor via torch.from_numpy() regardless of input device.
Docstring fix:

```diff
 def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
     """
     Voronoi assignment to connected components (CPU, single EDT) without cc3d.
-    Returns the ID of the nearest component for each voxel.
     Args:
-        labels: input label map as a numpy array, where values > 0 are considered seeds for connected components.
-        connectivity: 6/18/26 (3D)
-        sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt)
+        labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
+        connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
+        sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
+
+    Returns:
+        torch.Tensor: Voronoi region IDs (int32) on CPU.
     """
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test and benchmark the code to ensure it meets the requirements.
```python
def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
    """
    Voronoi assignment to connected components (CPU, single EDT) without cc3d.
    Args:
        labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
        connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26.
        sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances.
    Returns:
        torch.Tensor: Voronoi region IDs (int32) on CPU.
    """
    x = np.asarray(labels)
    conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
    structure = generate_binary_structure(rank=3, connectivity=conn_rank)
    cc, num = sn_label(x > 0, structure=structure)
    if num == 0:
        return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
    edt_input = np.ones(cc.shape, dtype=np.uint8)
    edt_input[cc > 0] = 0
    indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True)
    voronoi = cc[tuple(indices)]
    return torch.from_numpy(voronoi)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/metrics/meandice.py` around lines 300 - 321, The
compute_voronoi_regions_fast function's docstring lacks a Returns section and
the function always returns a CPU tensor (torch.from_numpy) even if the original
input was a CUDA tensor; update the docstring to include a Returns: description
and type (torch.Tensor on same device as input) and change the implementation to
preserve input type/device: accept numpy array or torch.Tensor for labels,
record the original device and dtype (if torch.Tensor), convert input to CPU
numpy for EDT processing, then convert the resulting voronoi numpy array back to
a torch.Tensor and move it to the original device and appropriate dtype before
returning; reference compute_voronoi_regions_fast, labels, edt_input, indices,
and voronoi when locating where to apply these changes.
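Conceptually, the function under review assigns every voxel the ID of its nearest connected component. A brute-force NumPy sketch in 1D shows the idea (the real code obtains this from a single `distance_transform_edt(..., return_indices=True)` call instead of pairwise distances):

```python
import numpy as np

def voronoi_assign_1d(cc: np.ndarray) -> np.ndarray:
    """Assign each position the label of its nearest seed (cc > 0), brute force."""
    seeds = np.flatnonzero(cc > 0)
    if seeds.size == 0:
        return np.zeros_like(cc)
    pos = np.arange(cc.size)
    # distance from every position to every seed; take the nearest seed's label
    nearest = seeds[np.abs(pos[:, None] - seeds[None, :]).argmin(axis=1)]
    return cc[nearest]

cc = np.array([1, 1, 0, 0, 0, 0, 2, 2])  # two labeled components: 1 and 2
print(voronoi_assign_1d(cc))  # → [1 1 1 1 2 2 2 2]
```

The EDT-based version produces the same partition in O(n) per voxel rather than O(n × seeds).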
```python
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
    for each batch item and for each channel of those items.
    Args:
        y_pred: input predictions with shape HW[D].
        y: ground truth with shape HW[D].
    """
    data = []
    if y_pred.ndim == y.ndim:
        y_pred_idx = torch.argmax(y_pred, dim=1)
        y_idx = torch.argmax(y, dim=1)
    else:
        y_pred_idx = y_pred
        y_idx = y

    if y_idx[0].sum() == 0:
        if y_pred_idx.sum() == 0:
            data.append(torch.tensor(1.0, device=y_idx.device))
        else:
            data.append(torch.tensor(0.0, device=y_idx.device))
    else:
        cc_assignment = self.compute_voronoi_regions_fast(y_idx[0])
        uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
        nof_components = uniq.numel()
        code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
        idx = (inv << 2) | code
        hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
        _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
        denom = 2 * tp + fp + fn
        dice_scores = torch.where(
            denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
        )
        data.append(dice_scores.unsqueeze(-1))
        data = [
            torch.where(torch.isinf(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
        ]
        data = [
            torch.where(torch.isnan(x), torch.tensor(0.0, dtype=torch.float32, device=x.device), x) for x in data
        ]
    data = [x.reshape(-1, 1) for x in data]
    return torch.stack([x.mean() for x in data])
```
Docstring states wrong shape; ignore_empty flag ignored.

- Docstring says HW[D], but the method receives (1, C, D, H, W) from line 425.
- self.ignore_empty is not consulted: an empty GT always returns 1.0 or 0.0 (lines 340-343), contradicting class behavior.
- Lines 357-362 (inf/nan replacement) only run in the else branch; unreachable for the empty-GT case.
Fix docstring and respect ignore_empty:

```diff
 def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     """
-    Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
-    for each batch item and for each channel of those items.
+    Compute per-component Dice for a single batch item.
     Args:
-        y_pred: input predictions with shape HW[D].
-        y: ground truth with shape HW[D].
+        y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
+        y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
+
+    Returns:
+        torch.Tensor: Mean Dice over connected components.
     """
     data = []
     if y_pred.ndim == y.ndim:
         y_pred_idx = torch.argmax(y_pred, dim=1)
         y_idx = torch.argmax(y, dim=1)
     else:
         y_pred_idx = y_pred
         y_idx = y
     if y_idx[0].sum() == 0:
-        if y_pred_idx.sum() == 0:
-            data.append(torch.tensor(1.0, device=y_idx.device))
-        else:
-            data.append(torch.tensor(0.0, device=y_idx.device))
+        if self.ignore_empty:
+            data.append(torch.tensor(float("nan"), device=y_idx.device))
+        elif y_pred_idx.sum() == 0:
+            data.append(torch.tensor(1.0, device=y_idx.device))
+        else:
+            data.append(torch.tensor(0.0, device=y_idx.device))
     else:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/metrics/meandice.py` around lines 323 - 364, The compute_cc_dice
method's docstring and empty-ground-truth handling are incorrect: update the
docstring for compute_cc_dice to state the actual expected input shapes (e.g.,
tensors that may include batch and channel dims such as (1, C, D, H, W) or
per-channel/per-item spatial tensors) and then change the empty-GT branch (the
y_idx[0].sum() == 0 case) to consult self.ignore_empty (return
torch.tensor(0.0/1.0 or skip/ignore according to class semantics) instead of
always appending 1.0/0.0), and move the inf/nan replacement logic (the two
torch.where lines that sanitize values) out of the else block so they run for
both empty and non-empty cases; refer to symbols compute_cc_dice, y_idx,
y_pred_idx, self.ignore_empty, cc_assignment, uniq/inv/hist/dice_scores to
locate and update the logic and docstring.
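For reference, the bincount trick used by compute_cc_dice packs (region, gt, pred) into one integer index so a single histogram yields per-region TP/FP/FN counts. A NumPy sketch of the same encoding (the helper name is illustrative, not part of the PR):

```python
import numpy as np

def per_region_dice(region: np.ndarray, gt: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Per-region Dice via one histogram; code 0=TN, 1=FP, 2=FN, 3=TP."""
    n_regions = int(region.max()) + 1
    code = (gt.astype(np.int64) << 1) | pred.astype(np.int64)
    idx = (region.astype(np.int64) << 2) | code  # 4 bins per region
    hist = np.bincount(idx.ravel(), minlength=n_regions * 4).reshape(-1, 4)
    tp, fp, fn = hist[:, 3], hist[:, 1], hist[:, 2]
    denom = 2 * tp + fp + fn
    # Dice = 2*TP / (2*TP + FP + FN); empty regions score 1.0
    return np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 1.0)

region = np.array([0, 0, 0, 1, 1, 1])  # Voronoi region index per voxel
gt     = np.array([1, 1, 0, 1, 1, 1])
pred   = np.array([1, 0, 0, 0, 0, 0])
print(per_region_dice(region, gt, pred))  # region 0: 2/3, region 1: 0.0
```

One pass over the flattened volume replaces a Python loop over components, which is why the method stays fast even with many connected components.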
```python
        f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)."
    )

first_ch = 0 if self.include_background and not self.per_component else 1
```
include_background silently ignored when per_component=True.

This line forces first_ch=1 regardless of include_background. Document this or raise a warning when the two flags conflict.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@monai/metrics/meandice.py` at line 416, The code currently sets first_ch
based on a combined condition which silently ignores include_background when
per_component is True; update the logic in the MeanDice/meandice implementation
to detect the conflicting flags (self.per_component True and
self.include_background False) and emit a clear warning (e.g., warnings.warn or
using the module logger) that include_background will be ignored when
per_component is enabled, then keep the existing behavior for first_ch (set
first_ch=1) to preserve compatibility; reference the attributes
self.include_background, self.per_component and the local variable first_ch so
reviewers can locate and adjust the check and add the warning.
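One way to surface the conflict is a constructor-time warning. A toy sketch (the class is a stand-in mirroring the review's attribute names, not the real DiceMetric):

```python
import warnings

class DiceMetricSketch:
    """Illustrates warning when include_background is overridden by per_component."""

    def __init__(self, include_background: bool = True, per_component: bool = False):
        self.include_background = include_background
        self.per_component = per_component
        if self.per_component and self.include_background:
            # per_component always excludes the background channel, so the flag has no effect
            warnings.warn(
                "include_background=True is ignored when per_component=True.",
                UserWarning,
            )

    def first_channel(self) -> int:
        # preserves the existing behavior flagged in the review
        return 0 if self.include_background and not self.per_component else 1
```

Warning at construction (rather than per call) avoids log spam when the metric is invoked once per batch.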
…itai - docstring issues, ignore_empty bug

Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
Fixes #8733
Description
This PR adds support for connected component-based Dice metric calculation to the existing DiceMetric and DiceHelper classes.
Types of changes
- In-place tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation builds with the `make html` command in the `docs/` folder.