
Vectorize DiceHelper.__call__() #8764

Open

aymuos15 wants to merge 6 commits into Project-MONAI:dev from aymuos15:optimize/vectorize-dice-metric

Conversation

Contributor

@aymuos15 aymuos15 commented Mar 3, 2026

Summary

  • Replace nested batch/channel Python loops in DiceHelper.__call__() with vectorized torch operations
  • Remove dead compute_channel method (no longer called after vectorization)

Note on memory

The current implementation vectorizes across both batch and channel dimensions simultaneously, which increases peak memory for large 3D volumes. If this becomes an issue, we can switch to looping over batch while keeping channels vectorized — this retains most of the speedup while keeping memory proportional to a single sample.
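The trade-off can be illustrated with a minimal sketch (not the PR's actual code; the function names and the smoothed Dice formula here are illustrative assumptions):

```python
import torch

def dice_fully_vectorized(y_pred, y, eps=1e-8):
    """Dice per (batch, channel); peak memory grows with batch * channels * spatial."""
    b, c = y_pred.shape[:2]
    p = y_pred.reshape(b, c, -1).float()
    t = y.reshape(b, c, -1).float()
    inter = (p * t).sum(-1)
    return (2.0 * inter + eps) / (p.sum(-1) + t.sum(-1) + eps)

def dice_batch_looped(y_pred, y, eps=1e-8):
    """Loop over batch, keep channels vectorized; peak memory scales with one sample."""
    return torch.stack(
        [dice_fully_vectorized(p[None], t[None], eps)[0] for p, t in zip(y_pred, y)]
    )
```

Both variants produce identical scores; the batch-looped form only changes how much intermediate memory is alive at once.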

Test plan

  • All existing test_compute_meandice tests pass
  • New test cases added covering mixed-format edge cases:
    • Single-channel y (class indices) + multi-channel y_pred (one-hot)
    • Single-channel y_pred (argmaxed, with num_classes) + multi-channel y (one-hot)
    • Same as above with include_background=False and batched inputs
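For context, the mixed formats exercised by these cases look roughly like this (a sketch; the shapes and class count are made up for illustration):

```python
import torch
import torch.nn.functional as F

num_classes, b, h, w = 3, 2, 4, 4

# Single-channel y holding class indices, shape (B, 1, H, W).
y = torch.randint(0, num_classes, (b, 1, h, w))

# Multi-channel one-hot y_pred, shape (B, C, H, W).
y_pred = F.one_hot(torch.randint(0, num_classes, (b, h, w)), num_classes)
y_pred = y_pred.movedim(-1, 1).float()

# Each tensor must be expanded based on its *own* channel count:
# y gets one-hot expanded to match y_pred before Dice is computed.
y_onehot = F.one_hot(y[:, 0].long(), num_classes).movedim(-1, 1).float()

assert y_onehot.shape == y_pred.shape == (b, num_classes, h, w)
```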

aymuos15 added 3 commits March 3, 2026 17:09
Replace nested batch/channel loops with vectorized torch operations.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
No longer called after vectorization of __call__().

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Convert y_pred and y to boolean independently based on each tensor's
own channel count, fixing incorrect Dice values when formats differ
(e.g. single-channel class indices paired with multi-channel one-hot).

Add test cases covering both mixed-format combinations.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Contributor

coderabbitai bot commented Mar 3, 2026

📝 Walkthrough

Walkthrough

The PR refactors DiceHelper in monai/metrics/meandice.py by removing the per-channel compute_channel method and implementing a vectorized Dice computation inside DiceHelper.__call__. Predictions and targets are converted to boolean masks (supporting multi-channel one-hot and single-channel class indices), spatial dims are flattened, and per-(batch,channel) intersections and sums are computed in a single vectorized operation. Background inclusion/exclusion and ignore_empty semantics (producing NaNs or per-case values) are preserved. Tests are extended with three mixed-format cases exercising one-hot/argmax predictions and class-index targets.
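The steps the walkthrough describes can be sketched as follows (an illustrative reconstruction, not MONAI's actual code; the empty-channel conventions are assumptions based on this description):

```python
import torch

def dice_from_masks(y_pred, y, include_background=True, ignore_empty=True):
    """Vectorized Dice over (B, C, *spatial) masks already in matching channel format."""
    if not include_background:
        y_pred, y = y_pred[:, 1:], y[:, 1:]
    b, c = y_pred.shape[:2]
    p = y_pred.reshape(b, c, -1).float()   # flatten spatial dims
    t = y.reshape(b, c, -1).float()
    inter = (p * t).sum(-1)                # per-(batch, channel) intersection
    t_sum = t.sum(-1)
    denom = p.sum(-1) + t_sum
    scores = 2.0 * inter / denom           # NaN where both masks are empty (0/0)
    if ignore_empty:
        # empty ground-truth channels become NaN so reductions can skip them
        scores = torch.where(t_sum > 0, scores, torch.full_like(scores, float("nan")))
    else:
        # assumed convention: empty gt + empty prediction counts as a perfect match
        scores = torch.where(denom > 0, scores, torch.ones_like(scores))
    return scores
```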

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 25.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check — ❓ Inconclusive: The description covers the main changes, memory trade-offs, and test plan, but lacks the required template structure (issue reference, types-of-changes checkboxes, test verification statuses). Resolution: follow the repository's PR description template — add an issue reference (Fixes #...), complete all required checkboxes under Types of changes, and verify test/documentation requirements.
✅ Passed checks (1 passed)
  • Title check — ✅ Passed: The title accurately summarizes the main change: vectorizing DiceHelper.__call__() to replace nested loops with vectorized torch operations.


Contributor

coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/metrics/test_compute_meandice.py (1)

318-321: Add one mixed-format test for ignore_empty=False.

The current mixed cases only validate the default ignore_empty=True; the new ignore_empty=False branch in DiceHelper.__call__ remains untested for mixed channel formats.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/metrics/test_compute_meandice.py` around lines 318 - 321, The test
suite lacks coverage for the new ignore_empty=False branch in
DiceHelper.__call__; add a unit test that calls compute_dice with a mixed-format
case and ignore_empty=False to assert the NaN behavior. Modify or add a
parameterized test (similar to test_nans) that includes a mixed-format test case
(e.g., one of TEST_CASE_MIXED_2/TEST_CASE_MIXED_3 or a new
TEST_CASE_MIXED_IGNORE_EMPTY) and pass ignore_empty=False in the compute_dice
call, then assert np.allclose(np.isnan(result.cpu().numpy()), expected_value)
with the expected mask for the ignore_empty=False semantics.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@monai/metrics/meandice.py`:
- Around line 311-328: The code currently casts multi-channel
predictions/targets to bool (y_pred_bool, y_bool), which destroys
soft/probabilistic values; instead, preserve float-valued soft labels for
multi-channel inputs: only perform one-hot expansion when inputs are
single-channel integer class indices (when y_pred.shape[1] == 1 and n_pred_ch >
1 or y.shape[1] == 1 and n_pred_ch > 1), but for multi-channel float tensors
leave them as float tensors (no .bool()), and then reshape to (batch_size,
n_pred_ch, -1) and .float() for downstream Dice computation (y_pred_flat,
y_flat). Update the branches that assign y_pred_bool and y_bool to produce
float-preserving tensors and rename/keep intended variable names used later
(y_pred_bool/y_bool -> still usable as channel-wise float masks) so the rest of
the code (y_pred_flat, y_flat) receives float soft labels instead of binary
values.


ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 583d5ca and 39c06c3.

📒 Files selected for processing (2)
  • monai/metrics/meandice.py
  • tests/metrics/test_compute_meandice.py

aymuos15 added 3 commits March 3, 2026 17:47
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
aymuos15 force-pushed the optimize/vectorize-dice-metric branch from 39c06c3 to 260d644 on March 3, 2026 at 17:53
Contributor

coderabbitai bot left a comment

♻️ Duplicate comments (1)
monai/metrics/meandice.py (1)

311-316: ⚠️ Potential issue | 🟠 Major

y_pred soft-label values still lost via .bool() on multi-channel inputs.

Line 316 converts multi-channel y_pred to boolean, which destroys soft/probabilistic values. The class docstring explicitly states soft labels are permitted. While y handling (line 324) was fixed, y_pred still has this issue.

Consider preserving float values for multi-channel y_pred:

Proposed fix

```diff
         if y_pred.shape[1] == 1 and n_pred_ch > 1:
-            y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device)
-            for c in range(n_pred_ch):
-                y_pred_bool[:, c] = y_pred[:, 0] == c
+            y_pred_expanded = torch.nn.functional.one_hot(
+                y_pred[:, 0].long(), num_classes=n_pred_ch
+            ).movedim(-1, 1).to(device=device, dtype=torch.float32)
         else:
-            y_pred_bool = y_pred.bool()
+            y_pred_expanded = y_pred.float()
```
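One way to sanity-check that the one_hot replacement matches the per-class loop it removes (a toy sketch with made-up shapes; variable names are illustrative, not MONAI's):

```python
import torch
import torch.nn.functional as F

b, c, h, w = 2, 3, 4, 4
labels = torch.randint(0, c, (b, 1, h, w))   # single-channel class indices

# The per-class loop the diff removes:
looped = torch.zeros(b, c, h, w, dtype=torch.bool)
for ch in range(c):
    looped[:, ch] = labels[:, 0] == ch

# The proposed one_hot + movedim replacement:
vectorized = F.one_hot(labels[:, 0].long(), num_classes=c).movedim(-1, 1).bool()

assert torch.equal(looped, vectorized)
```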


🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 311 - 316, The multi-channel y_pred
branch currently calls .bool() which discards soft/probabilistic values; change
the logic so multi-channel predictions preserve float (soft) values instead of
converting to boolean: leave y_pred untouched when n_pred_ch > 1 (e.g., set
y_pred_bool = y_pred) and only use .bool() for single-channel/hard-label cases
(or when y_pred is integral), keeping the existing
single-channel-to-multi-channel one-hot conversion for y_pred.shape[1] == 1;
update references to y_pred_bool accordingly so downstream code handles float
soft-label tensors correctly (symbols: y_pred, y_pred_bool, n_pred_ch,
batch_size, device).
🧹 Nitpick comments (2)
monai/metrics/meandice.py (2)

281-289: Docstring missing Returns section.

Per coding guidelines, docstrings should document return values. This method returns either torch.Tensor or tuple[torch.Tensor, torch.Tensor] depending on get_not_nans.

Suggested addition

```diff
         y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
+
+        Returns:
+            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Dice scores per batch/channel.
+                If ``get_not_nans`` is True, returns ``(scores, not_nans)`` tuple.
         """
```

As per coding guidelines: "Docstrings should be present for all definition which describe each variable, return value, and raised exception."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 281 - 289, Update the __call__ method
docstring to include a Returns section that documents the return types and
conditions: specify it returns a torch.Tensor of per-batch or aggregated Dice
scores and, if get_not_nans is True, returns a tuple (torch.Tensor,
torch.Tensor) where the second tensor contains the counts/flags of non-NaN
entries; include shapes/axes semantics (e.g., per-class or aggregated depending
on num_classes and reduction) and clarify when the tuple is produced (based on
get_not_nans) and the dtype (torch.Tensor) to match the signature of __call__.

319-324: Consider torch.nn.functional.one_hot for cleaner expansion.

The explicit loop works but one_hot is more idiomatic and potentially faster.

Suggested refactor

```diff
         if y.shape[1] == 1 and n_pred_ch > 1:
-            y_expanded = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.float32, device=device)
-            for c in range(n_pred_ch):
-                y_expanded[:, c] = (y[:, 0] == c).float()
+            y_expanded = torch.nn.functional.one_hot(
+                y[:, 0].long(), num_classes=n_pred_ch
+            ).movedim(-1, 1).to(device=device, dtype=torch.float32)
         else:
             y_expanded = y
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/metrics/meandice.py` around lines 319 - 324, Replace the explicit
Python loop that expands a single-channel label tensor with
torch.nn.functional.one_hot: when y.shape[1] == 1 and n_pred_ch > 1, call
F.one_hot on y[:,0].long() with num_classes=n_pred_ch, then convert to float,
move to device, and reshape/permute so the result becomes y_expanded of shape
(batch_size, n_pred_ch, *y.shape[2:]) — this keeps the same semantics as the
loop but is more idiomatic and faster; ensure you use the same dtype/device as
other tensors (referenced symbols: y_expanded, y, n_pred_ch, device,
batch_size).

ℹ️ Review info


📥 Commits

Reviewing files that changed from the base of the PR and between 39c06c3 and 260d644.

📒 Files selected for processing (2)
  • monai/metrics/meandice.py
  • tests/metrics/test_compute_meandice.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/metrics/test_compute_meandice.py

