Fix DiceFocalLoss to apply activation before removing background#8947
Fix DiceFocalLoss to apply activation before removing background#8947luomi16 wants to merge 1 commit into
Conversation
When using include_background=False with softmax=True or sigmoid=True in DiceFocalLoss for binary segmentation, the activation was being ignored because the background channel was removed before the activation was applied. This fix applies the activation (softmax/sigmoid/other_act) BEFORE removing the background channel, ensuring that the activation is applied to all channels as intended. The fix: 1. Stores sigmoid, softmax, and other_act as instance variables 2. Applies activation in forward() before removing background 3. Disables activation in the internal DiceLoss instance to avoid double application Fixes: Project-MONAI#5697
📝 WalkthroughWalkthroughDiceFocalLoss now builds its internal DiceLoss with activation disabled, stores its own sigmoid/softmax/other_act settings, and applies that activation in forward before optionally removing the background channel. The one-hot target conversion and single-channel handling remain unchanged. Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes 🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
⚔️ Resolve merge conflicts
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/dice.py (1)
852-872: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick winKeep logits for
FocalLoss.
inputis converted to probabilities, then passed toself.focal(...). That breaks the documented logits contract forFocalLoss.Proposed fix
+ dice_input = input + focal_input = input + # Apply activation before removing background to ensure softmax/sigmoid works correctly if self.sigmoid: - input = torch.sigmoid(input) + dice_input = torch.sigmoid(dice_input) elif self.softmax: if n_pred_ch == 1: warnings.warn("single channel prediction, `softmax=True` ignored.") else: - input = torch.softmax(input, 1) + dice_input = torch.softmax(dice_input, 1) elif self.other_act is not None: - input = self.other_act(input) + dice_input = self.other_act(dice_input) if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.") else: # if skipping background, removing first channel target = target[:, 1:] - input = input[:, 1:] + dice_input = dice_input[:, 1:] + focal_input = focal_input[:, 1:] - dice_loss = self.dice(input, target) - focal_loss = self.focal(input, target) + dice_loss = self.dice(dice_input, target) + focal_loss = self.focal(focal_input, target)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/losses/dice.py` around lines 852 - 872, Keep logits intact for the focal path in the dice loss implementation: in the section that applies activation and then calls self.dice(...) and self.focal(...), avoid reusing the activated input for FocalLoss. Compute the dice input with the existing sigmoid/softmax/other_act handling, but pass the original logits (or a separate untouched tensor) into self.focal so the FocalLoss contract remains correct. Use the dice() and focal() calls in dice.py to locate the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/losses/dice.py`:
- Line 857: The warning emitted in the single-channel prediction branch of
`DiceLoss` needs an explicit stack level so the caller sees the correct source
location. Update the `warnings.warn` call in `dice.py` to pass `stacklevel=2`,
keeping the existing message intact and making sure the change is applied in the
`DiceLoss` logic where `softmax=True` is ignored.
- Around line 825-827: DiceLoss is no longer enforcing mutually exclusive
activation settings, so invalid configs like sigmoid=True and softmax=True can
slip through when the activation path is skipped. Restore the exclusivity
validation in DiceLoss initialization or setup logic by checking the sigmoid,
softmax, and other_act flags together and raising an error when more than one
activation is enabled; use the DiceLoss constructor/validation flow to locate
and fix this.
---
Outside diff comments:
In `@monai/losses/dice.py`:
- Around line 852-872: Keep logits intact for the focal path in the dice loss
implementation: in the section that applies activation and then calls
self.dice(...) and self.focal(...), avoid reusing the activated input for
FocalLoss. Compute the dice input with the existing sigmoid/softmax/other_act
handling, but pass the original logits (or a separate untouched tensor) into
self.focal so the FocalLoss contract remains correct. Use the dice() and focal()
calls in dice.py to locate the change.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6c110aae-3952-4268-976c-23e1523294ce
📒 Files selected for processing (1)
monai/losses/dice.py
| self.sigmoid = sigmoid | ||
| self.softmax = softmax | ||
| self.other_act = other_act |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Preserve activation exclusivity validation.
With DiceLoss activation disabled, configs like sigmoid=True, softmax=True are now silently accepted and resolved by if/elif.
Proposed fix
self.to_onehot_y = to_onehot_y
self.include_background = include_background
+ if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
+ raise ValueError("Only one of sigmoid=True, softmax=True, or other_act may be specified.")
self.sigmoid = sigmoid
self.softmax = softmax
self.other_act = other_actAs per path instructions, "Examine code for logical error or inconsistencies".
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| self.sigmoid = sigmoid | |
| self.softmax = softmax | |
| self.other_act = other_act | |
| self.to_onehot_y = to_onehot_y | |
| self.include_background = include_background | |
| if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: | |
| raise ValueError("Only one of sigmoid=True, softmax=True, or other_act may be specified.") | |
| self.sigmoid = sigmoid | |
| self.softmax = softmax | |
| self.other_act = other_act |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/losses/dice.py` around lines 825 - 827, DiceLoss is no longer enforcing
mutually exclusive activation settings, so invalid configs like sigmoid=True and
softmax=True can slip through when the activation path is skipped. Restore the
exclusivity validation in DiceLoss initialization or setup logic by checking the
sigmoid, softmax, and other_act flags together and raising an error when more
than one activation is enabled; use the DiceLoss constructor/validation flow to
locate and fix this.
Source: Path instructions
| input = torch.sigmoid(input) | ||
| elif self.softmax: | ||
| if n_pred_ch == 1: | ||
| warnings.warn("single channel prediction, `softmax=True` ignored.") |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Set stacklevel on the new warning.
Ruff flags this warnings.warn call; use stacklevel=2 so callers see their call site.
Proposed fix
- warnings.warn("single channel prediction, `softmax=True` ignored.")
+ warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| warnings.warn("single channel prediction, `softmax=True` ignored.") | |
| warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2) |
🧰 Tools
🪛 Ruff (0.15.18)
[warning] 857-857: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/losses/dice.py` at line 857, The warning emitted in the single-channel
prediction branch of `DiceLoss` needs an explicit stack level so the caller sees
the correct source location. Update the `warnings.warn` call in `dice.py` to
pass `stacklevel=2`, keeping the existing message intact and making sure the
change is applied in the `DiceLoss` logic where `softmax=True` is ignored.
Source: Linters/SAST tools
Summary
This PR fixes a bug in
DiceFocalLosswhere usinginclude_background=Falsewithsoftmax=Trueorsigmoid=Truefor binary segmentation would cause the activation function to be ignored.Problem
When using
include_background=Falsewithsoftmax=TrueinDiceFocalLossfor binary segmentation, the Dice part would produce this warning and not apply softmax when calculating the Dice loss:This happened because the background channel was removed before the activation was applied, leaving only one channel for the Dice loss to process.
Solution
The fix applies the activation (softmax/sigmoid/other_act) before removing the background channel, ensuring that the activation is applied to all channels as intended.
Changes made:
sigmoid,softmax, andother_actas instance variables inDiceFocalLoss.__init__()forward()before removing backgroundDiceLossinstance to avoid double applicationTesting
include_background=Falsewithsoftmax=Trueorsigmoid=TrueRelated Issue
Fixes: #5697