From 773f5fca41e2a409a7515db7741d00d1bcf6040c Mon Sep 17 00:00:00 2001 From: openhands Date: Wed, 24 Jun 2026 20:10:19 +0000 Subject: [PATCH] Fix DiceFocalLoss to apply activation before removing background When using include_background=False with softmax=True or sigmoid=True in DiceFocalLoss for binary segmentation, the activation was being ignored because the background channel was removed before the activation was applied. This fix applies the activation (softmax/sigmoid/other_act) BEFORE removing the background channel, ensuring that the activation is applied to all channels as intended. The fix: 1. Stores sigmoid, softmax, and other_act as instance variables 2. Applies activation in forward() before removing background 3. Disables activation in the internal DiceLoss instance to avoid double application Fixes: https://github.com/Project-MONAI/MONAI/issues/5697 --- monai/losses/dice.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 869dd05ac2..e4965d297e 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -803,9 +803,9 @@ def __init__( """ super().__init__() self.dice = DiceLoss( - sigmoid=sigmoid, - softmax=softmax, - other_act=other_act, + sigmoid=False, + softmax=False, + other_act=None, squared_pred=squared_pred, jaccard=jaccard, reduction=reduction, @@ -822,6 +822,9 @@ def __init__( self.lambda_focal = lambda_focal self.to_onehot_y = to_onehot_y self.include_background = include_background + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -846,6 +849,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: else: target = one_hot(target, num_classes=n_pred_ch) + # Apply activation before removing background to ensure softmax/sigmoid works correctly + if self.sigmoid: + input = torch.sigmoid(input) + elif self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + elif self.other_act is not None: + input = self.other_act(input) + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.")