diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 869dd05ac2..e4965d297e 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -803,9 +803,9 @@ def __init__( """ super().__init__() self.dice = DiceLoss( - sigmoid=sigmoid, - softmax=softmax, - other_act=other_act, + sigmoid=False, + softmax=False, + other_act=None, squared_pred=squared_pred, jaccard=jaccard, reduction=reduction, @@ -822,6 +822,9 @@ def __init__( self.lambda_focal = lambda_focal self.to_onehot_y = to_onehot_y self.include_background = include_background + self.sigmoid = sigmoid + self.softmax = softmax + self.other_act = other_act def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -846,6 +849,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: else: target = one_hot(target, num_classes=n_pred_ch) + # Apply activation before removing background to ensure softmax/sigmoid works correctly + if self.sigmoid: + input = torch.sigmoid(input) + elif self.softmax: + if n_pred_ch == 1: + warnings.warn("single channel prediction, `softmax=True` ignored.") + else: + input = torch.softmax(input, 1) + elif self.other_act is not None: + input = self.other_act(input) + if not self.include_background: if n_pred_ch == 1: warnings.warn("single channel prediction, `include_background=False` ignored.")