diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py index 017606ac08..680ff7bc82 100644 --- a/monai/losses/hausdorff_loss.py +++ b/monai/losses/hausdorff_loss.py @@ -83,7 +83,7 @@ def __init__( super().__init__(reduction=LossReduction(reduction).value) if other_act is not None and not callable(other_act): raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") - if int(sigmoid) + int(softmax) > 1: + if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") self.alpha = alpha diff --git a/tests/losses/test_hausdorff_loss.py b/tests/losses/test_hausdorff_loss.py index f2211008c2..afc9dbe135 100644 --- a/tests/losses/test_hausdorff_loss.py +++ b/tests/losses/test_hausdorff_loss.py @@ -212,6 +212,12 @@ def test_ill_shape(self): def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): HausdorffDTLoss(sigmoid=True, softmax=True) + with self.assertRaisesRegex(ValueError, ""): + HausdorffDTLoss(sigmoid=True, other_act=torch.tanh) + with self.assertRaisesRegex(ValueError, ""): + HausdorffDTLoss(softmax=True, other_act=torch.tanh) + with self.assertRaisesRegex(ValueError, ""): + HausdorffDTLoss(sigmoid=True, softmax=True, other_act=torch.tanh) chn_input = torch.ones((1, 1, 3)) chn_target = torch.ones((1, 1, 3)) with self.assertRaisesRegex(ValueError, ""): @@ -244,6 +250,12 @@ def test_ill_shape(self): def test_ill_opts(self): with self.assertRaisesRegex(ValueError, ""): LogHausdorffDTLoss(sigmoid=True, softmax=True) + with self.assertRaisesRegex(ValueError, ""): + LogHausdorffDTLoss(sigmoid=True, other_act=torch.tanh) + with self.assertRaisesRegex(ValueError, ""): + LogHausdorffDTLoss(softmax=True, other_act=torch.tanh) + with self.assertRaisesRegex(ValueError, ""): + LogHausdorffDTLoss(sigmoid=True, softmax=True, other_act=torch.tanh) chn_input = torch.ones((1, 1, 3)) chn_target = torch.ones((1, 1, 3)) with self.assertRaisesRegex(ValueError, ""):