Conversation
Replace nested batch/channel loops with vectorized torch operations. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
No longer called after vectorization of __call__(). Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Convert y_pred and y to boolean independently based on each tensor's own channel count, fixing incorrect Dice values when formats differ (e.g. single-channel class indices paired with multi-channel one-hot). Add test cases covering both mixed-format combinations. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
📝 Walkthrough

The PR refactors `DiceHelper` in `monai/metrics/meandice.py` by removing the per-channel loop in favor of vectorized torch operations.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ❌ 2 failed (1 warning, 1 inconclusive) | ✅ 1 passed
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/metrics/test_compute_meandice.py (1)
318-321: Add one mixed-format test for `ignore_empty=False`.

Current mixed cases only validate the default `ignore_empty=True`; the new `ignore_empty=False` branch in `DiceHelper.__call__` remains untested for mixed channel formats. As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/metrics/test_compute_meandice.py` around lines 318 - 321, The test suite lacks coverage for the new ignore_empty=False branch in DiceHelper.__call__; add a unit test that calls compute_dice with a mixed-format case and ignore_empty=False to assert the NaN behavior. Modify or add a parameterized test (similar to test_nans) that includes a mixed-format test case (e.g., one of TEST_CASE_MIXED_2/TEST_CASE_MIXED_3 or a new TEST_CASE_MIXED_IGNORE_EMPTY) and pass ignore_empty=False in the compute_dice call, then assert np.allclose(np.isnan(result.cpu().numpy()), expected_value) with the expected mask for the ignore_empty=False semantics.
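The `ignore_empty` semantics the requested test would exercise can be sketched with a toy per-channel Dice in pure torch (`dice_per_channel` is a hypothetical helper written for illustration, not the MONAI implementation):

```python
import torch

def dice_per_channel(y_pred: torch.Tensor, y: torch.Tensor, ignore_empty: bool = True) -> torch.Tensor:
    # Toy Dice over boolean (B, C, *spatial) tensors mirroring the semantics
    # discussed above: with ignore_empty=True an empty ground-truth channel
    # yields NaN; with False it yields 1.0 when the prediction is also empty.
    dims = tuple(range(2, y.ndim))
    inter = (y_pred & y).sum(dim=dims).float()
    y_o = y.sum(dim=dims).float()
    denom = y_o + y_pred.sum(dim=dims).float()
    score = torch.where(denom > 0, 2.0 * inter / denom.clamp(min=1), torch.ones_like(denom))
    if ignore_empty:
        return torch.where(y_o > 0, score, torch.full_like(score, float("nan")))
    return score

y = torch.zeros(1, 2, 4, 4, dtype=torch.bool)
y[0, 0, :2] = True          # channel 0 has foreground, channel 1 is empty
y_pred = y.clone()          # perfect prediction, also empty on channel 1
print(dice_per_channel(y_pred, y, ignore_empty=True))   # channel 1 -> nan
print(dice_per_channel(y_pred, y, ignore_empty=False))  # channel 1 -> 1.0
```

A mixed-format test would build `y` as class indices and `y_pred` as one-hot (or vice versa) and assert the NaN mask under each `ignore_empty` setting.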
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/metrics/meandice.py`:
- Around line 311-328: The code currently casts multi-channel
predictions/targets to bool (y_pred_bool, y_bool), which destroys
soft/probabilistic values; instead, preserve float-valued soft labels for
multi-channel inputs: only perform one-hot expansion when inputs are
single-channel integer class indices (when y_pred.shape[1] == 1 and n_pred_ch >
1 or y.shape[1] == 1 and n_pred_ch > 1), but for multi-channel float tensors
leave them as float tensors (no .bool()), and then reshape to (batch_size,
n_pred_ch, -1) and .float() for downstream Dice computation (y_pred_flat,
y_flat). Update the branches that assign y_pred_bool and y_bool to produce
float-preserving tensors and rename/keep intended variable names used later
(y_pred_bool/y_bool -> still usable as channel-wise float masks) so the rest of
the code (y_pred_flat, y_flat) receives float soft labels instead of binary
values.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (2)
- monai/metrics/meandice.py
- tests/metrics/test_compute_meandice.py
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Branch updated from 39c06c3 to 260d644.
♻️ Duplicate comments (1)
monai/metrics/meandice.py (1)
311-316: ⚠️ Potential issue | 🟠 Major

`y_pred` soft-label values still lost via `.bool()` on multi-channel inputs.

Line 316 converts multi-channel `y_pred` to boolean, which destroys soft/probabilistic values. The class docstring explicitly states soft labels are permitted. While `y` handling (line 324) was fixed, `y_pred` still has this issue. Consider preserving float values for multi-channel `y_pred`:

Proposed fix

```diff
 if y_pred.shape[1] == 1 and n_pred_ch > 1:
-    y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device)
-    for c in range(n_pred_ch):
-        y_pred_bool[:, c] = y_pred[:, 0] == c
+    y_pred_expanded = torch.nn.functional.one_hot(
+        y_pred[:, 0].long(), num_classes=n_pred_ch
+    ).movedim(-1, 1).to(device=device, dtype=torch.float32)
 else:
-    y_pred_bool = y_pred.bool()
+    y_pred_expanded = y_pred.float()
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 311 - 316, The multi-channel y_pred branch currently calls .bool() which discards soft/probabilistic values; change the logic so multi-channel predictions preserve float (soft) values instead of converting to boolean: leave y_pred untouched when n_pred_ch > 1 (e.g., set y_pred_bool = y_pred) and only use .bool() for single-channel/hard-label cases (or when y_pred is integral), keeping the existing single-channel-to-multi-channel one-hot conversion for y_pred.shape[1] == 1; update references to y_pred_bool accordingly so downstream code handles float soft-label tensors correctly (symbols: y_pred, y_pred_bool, n_pred_ch, batch_size, device).
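A minimal illustration of the issue flagged above: casting soft probabilities through `.bool()` collapses every nonzero value to `True`, so the soft intersection term is lost (the tensors here are made up for demonstration):

```python
import torch

# Soft predictions for 2 channels over a 1x2 spatial grid, shape (1, 2, 1, 2)
probs = torch.tensor([[[[0.6, 0.1]], [[0.4, 0.9]]]])
y = torch.tensor([[[[1.0, 0.0]], [[0.0, 1.0]]]])  # one-hot ground truth

hard = probs.bool()          # every nonzero probability becomes True
assert hard.all().item()     # 0.1 and 0.9 are now indistinguishable

inter_soft = (probs * y).sum(dim=(2, 3))                # soft intersection
inter_hard = (hard & y.bool()).sum(dim=(2, 3)).float()  # after .bool()
print(inter_soft)  # tensor([[0.6000, 0.9000]])
print(inter_hard)  # tensor([[1., 1.]])
```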
🧹 Nitpick comments (2)
monai/metrics/meandice.py (2)
281-289: Docstring missing `Returns` section.

Per coding guidelines, docstrings should document return values. This method returns either `torch.Tensor` or `tuple[torch.Tensor, torch.Tensor]` depending on `get_not_nans`.

Suggested addition

```diff
     y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
+
+Returns:
+    torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Dice scores per batch/channel.
+    If ``get_not_nans`` is True, returns ``(scores, not_nans)`` tuple.
 """
```

As per coding guidelines: "Docstrings should be present for all definition which describe each variable, return value, and raised exception."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 281 - 289, Update the __call__ method docstring to include a Returns section that documents the return types and conditions: specify it returns a torch.Tensor of per-batch or aggregated Dice scores and, if get_not_nans is True, returns a tuple (torch.Tensor, torch.Tensor) where the second tensor contains the counts/flags of non-NaN entries; include shapes/axes semantics (e.g., per-class or aggregated depending on num_classes and reduction) and clarify when the tuple is produced (based on get_not_nans) and the dtype (torch.Tensor) to match the signature of __call__.
319-324: Consider `torch.nn.functional.one_hot` for cleaner expansion.

The explicit loop works but `one_hot` is more idiomatic and potentially faster.

Suggested refactor

```diff
 if y.shape[1] == 1 and n_pred_ch > 1:
-    y_expanded = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.float32, device=device)
-    for c in range(n_pred_ch):
-        y_expanded[:, c] = (y[:, 0] == c).float()
+    y_expanded = torch.nn.functional.one_hot(
+        y[:, 0].long(), num_classes=n_pred_ch
+    ).movedim(-1, 1).to(device=device, dtype=torch.float32)
 else:
     y_expanded = y
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/metrics/meandice.py` around lines 319 - 324, Replace the explicit Python loop that expands a single-channel label tensor with torch.nn.functional.one_hot: when y.shape[1] == 1 and n_pred_ch > 1, call F.one_hot on y[:,0].long() with num_classes=n_pred_ch, then convert to float, move to device, and reshape/permute so the result becomes y_expanded of shape (batch_size, n_pred_ch, *y.shape[2:]) — this keeps the same semantics as the loop but is more idiomatic and faster; ensure you use the same dtype/device as other tensors (referenced symbols: y_expanded, y, n_pred_ch, device, batch_size).
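The suggested `one_hot` refactor can be checked for equivalence against the explicit loop with a self-contained sketch (shapes are made up for demonstration):

```python
import torch
import torch.nn.functional as F

batch_size, n_pred_ch = 2, 3
y = torch.randint(0, n_pred_ch, (batch_size, 1, 4, 4))  # class indices, (B, 1, H, W)

# Explicit loop, as in the current code
loop = torch.zeros(batch_size, n_pred_ch, 4, 4)
for c in range(n_pred_ch):
    loop[:, c] = (y[:, 0] == c).float()

# one_hot produces (B, H, W, C); movedim puts channels back at dim 1
vec = F.one_hot(y[:, 0].long(), num_classes=n_pred_ch).movedim(-1, 1).float()
assert torch.equal(loop, vec)
```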
📒 Files selected for processing (2)
- monai/metrics/meandice.py
- tests/metrics/test_compute_meandice.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/metrics/test_compute_meandice.py
Summary

- Replace nested batch/channel loops in `DiceHelper.__call__()` with vectorized torch operations
- Remove the `compute_channel` method (no longer called after vectorization)

Note on memory
The current implementation vectorizes across both batch and channel dimensions simultaneously, which increases peak memory for large 3D volumes. If this becomes an issue, we can switch to looping over batch while keeping channels vectorized — this retains most of the speedup while keeping memory proportional to a single sample.
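The fallback described in the note can be sketched as follows — a toy Dice, not the MONAI implementation, shown only to illustrate the memory/speed tradeoff:

```python
import torch

def dice_vectorized(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Vectorized over batch and channel at once; peak memory scales with
    # the full (B, C, *spatial) volume.
    dims = tuple(range(2, y.ndim))
    inter = (y_pred * y).sum(dim=dims)
    denom = y_pred.sum(dim=dims) + y.sum(dim=dims)
    return 2.0 * inter / denom.clamp(min=1e-8)

def dice_batch_loop(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Loops over batch, keeps channels vectorized; memory scales with a
    # single sample while retaining most of the speedup.
    return torch.stack([dice_vectorized(p[None], t[None])[0] for p, t in zip(y_pred, y)])

y = torch.randint(0, 2, (4, 3, 8, 8, 8)).float()
assert torch.allclose(dice_vectorized(y, y), dice_batch_loop(y, y))
```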
Test plan

- Existing `test_compute_meandice` tests pass
- New tests: single-channel `y` (class indices) + multi-channel `y_pred` (one-hot)
- New tests: single-channel `y_pred` (argmaxed, with `num_classes`) + multi-channel `y` (one-hot)
- Coverage for `include_background=False` and batched inputs