From 8b2e31693b890a52d982a12bbd6ce4e343e7dc13 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 17:09:26 +0000 Subject: [PATCH 1/5] perf: Vectorize DiceHelper.__call__() Replace nested batch/channel loops with vectorized torch operations. Signed-off-by: Soumya Snigdha Kundu --- monai/metrics/meandice.py | 63 +++++++++++++++++++++++++++++++++------ 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index fedd94fb93..57fb257b9e 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -322,16 +322,61 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl y_pred = torch.sigmoid(y_pred) y_pred = y_pred > 0.5 + # Vectorized computation (replaces nested loops for better performance) + batch_size = y_pred.shape[0] + device = y_pred.device + + # Convert to boolean for computation + if y_pred.shape[1] == 1 and n_pred_ch > 1: + # Single-channel class indices: convert to one-hot + y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device) + y_bool = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.bool, device=device) + + for c in range(n_pred_ch): + y_pred_bool[:, c] = (y_pred[:, 0] == c) + y_bool[:, c] = (y[:, 0] == c) + else: + # One-hot format: cast to bool + y_pred_bool = y_pred.bool() + if y.shape[1] == 1 and y_pred.shape[1] > 1: + # Expand y to match y_pred channels + y_bool = (y == 1).expand(batch_size, n_pred_ch, *y.shape[2:]) + else: + y_bool = y.bool() + + # Flatten spatial dimensions for vectorized computation: (batch, channels, -1) + y_pred_flat = y_pred_bool.reshape(batch_size, n_pred_ch, -1).float() + y_flat = y_bool.reshape(batch_size, n_pred_ch, -1).float() + + # Compute Dice per (batch, channel) vectorized: all reductions at once + intersection = torch.sum(y_pred_flat * y_flat, dim=-1) # (batch, n_pred_ch) + pred_sum = torch.sum(y_pred_flat, dim=-1) # (batch, n_pred_ch) + y_sum = torch.sum(y_flat, dim=-1) # (batch, n_pred_ch) + + # Dice formula: 2 * intersection / (pred_sum + y_sum) + union = pred_sum + y_sum + dice = (2.0 * intersection) / union # (batch, n_pred_ch) + + # Handle empty ground truth cases + if self.ignore_empty: + # Set NaN where ground truth is empty + dice = torch.where(y_sum > 0, dice, torch.tensor(float("nan"), device=device, dtype=dice.dtype)) + else: + # Set 1.0 if both empty, 0.0 if only pred is non-empty + empty_mask = y_sum == 0 + dice = torch.where( + empty_mask, + torch.where(pred_sum == 0, torch.tensor(1.0, device=device, dtype=dice.dtype), + torch.tensor(0.0, device=device, dtype=dice.dtype)), + dice + ) + + # Select channels: exclude background if requested first_ch = 0 if self.include_background else 1 - data = [] - for b in range(y_pred.shape[0]): - c_list = [] - for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: - x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() - x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] - c_list.append(self.compute_channel(x_pred, x)) - data.append(torch.stack(c_list)) - data = torch.stack(data, dim=0).contiguous() # type: ignore + if n_pred_ch > 1: + data = dice[:, first_ch:] # (batch, num_classes_selected) + else: + data = dice # (batch, 1) f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore return (f, not_nans) if self.get_not_nans else f From 5d1456ef84263454373be2c460342281dcab4d01 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 17:19:55 +0000 Subject: [PATCH 2/5] refactor: Remove dead compute_channel method from DiceHelper No longer called after vectorization of __call__(). Signed-off-by: Soumya Snigdha Kundu --- monai/metrics/meandice.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 57fb257b9e..00882014b2 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -278,25 +278,6 @@ def __init__( self.ignore_empty = ignore_empty self.num_classes = num_classes - def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """ - Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately - for each batch item and for each channel of those items. - - Args: - y_pred: input predictions with shape HW[D]. - y: ground truth with shape HW[D]. - """ - y_o = torch.sum(y) - if y_o > 0: - return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred)) - if self.ignore_empty: - return torch.tensor(float("nan"), device=y_o.device) - denorm = y_o + torch.sum(y_pred) - if denorm <= 0: - return torch.tensor(1.0, device=y_o.device) - return torch.tensor(0.0, device=y_o.device) - def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ Compute the metric for the given prediction and ground truth. From 3c27d3fea6e3651079e2d3428721d138731126d2 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 17:25:23 +0000 Subject: [PATCH 3/5] fix: Handle mixed single/multi-channel y and y_pred in DiceHelper Convert y_pred and y to boolean independently based on each tensor's own channel count, fixing incorrect Dice values when formats differ (e.g. single-channel class indices paired with multi-channel one-hot). Add test cases covering both mixed-format combinations. Signed-off-by: Soumya Snigdha Kundu --- monai/metrics/meandice.py | 20 ++++----- tests/metrics/test_compute_meandice.py | 60 +++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 00882014b2..30667482c1 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -307,23 +307,21 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl batch_size = y_pred.shape[0] device = y_pred.device - # Convert to boolean for computation + # Convert y_pred to boolean (handle single-channel class indices vs multi-channel one-hot independently) if y_pred.shape[1] == 1 and n_pred_ch > 1: - # Single-channel class indices: convert to one-hot y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device) - y_bool = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.bool, device=device) - for c in range(n_pred_ch): y_pred_bool[:, c] = (y_pred[:, 0] == c) - y_bool[:, c] = (y[:, 0] == c) else: - # One-hot format: cast to bool y_pred_bool = y_pred.bool() - if y.shape[1] == 1 and y_pred.shape[1] > 1: - # Expand y to match y_pred channels - y_bool = (y == 1).expand(batch_size, n_pred_ch, *y.shape[2:]) - else: - y_bool = y.bool() + + # Convert y to boolean (independent of y_pred format) + if y.shape[1] == 1 and n_pred_ch > 1: + y_bool = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.bool, device=device) + for c in range(n_pred_ch): + y_bool[:, c] = (y[:, 0] == c) + else: + y_bool = y.bool() # Flatten spatial dimensions for vectorized computation: (batch, channels, -1) y_pred_flat = y_pred_bool.reshape(batch_size, n_pred_ch, -1).float() diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index 04c81ff9a7..151d9e231f 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -251,15 +251,71 @@ ] +# single-channel y (class indices) with multi-channel y_pred (one-hot) +TEST_CASE_MIXED_1 = [ + { + "y_pred": torch.tensor( + [[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]]]] + ), # (1, 3, 2, 2) one-hot + "y": torch.tensor([[[[0.0, 1.0], [2.0, 1.0]]]]), # (1, 1, 2, 2) class indices + "include_background": True, + }, + # class 0: y_gt=[[1,0],[0,0]], y_pred=[[0,1],[0,0]] -> dice=0.0 + # class 1: y_gt=[[0,1],[0,1]], y_pred=[[0,0],[0,1]] -> dice=2/3 + # class 2: y_gt=[[0,0],[1,0]], y_pred=[[1,0],[1,0]] -> dice=2/3 + [[0.0000, 0.6667, 0.6667]], +] + +# single-channel y_pred (argmaxed, with num_classes) with multi-channel y (one-hot) +TEST_CASE_MIXED_2 = [ + { + "y_pred": torch.tensor([[[[2.0, 2.0], [2.0, 2.0]]]]), # (1, 1, 2, 2) all class 2 + "y": torch.tensor( + [[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]] + ), # (1, 3, 2, 2) one-hot, all background + "include_background": True, + "num_classes": 3, + }, + # class 0: y_gt=[1,1,1,1](4), y_pred=[0,0,0,0](0) -> dice=0.0 + # class 1: y_gt=[0,0,0,0](0), y_pred=[0,0,0,0](0) -> dice=nan (ignore_empty default) + # class 2: y_gt=[0,0,0,0](0), y_pred=[1,1,1,1](4) -> dice=nan (ignore_empty default) + [[False, True, True]], # False=not-nan, True=nan +] + +# single-channel y (class indices) with multi-channel y_pred, exclude background +TEST_CASE_MIXED_3 = [ + { + "y_pred": torch.tensor( + [ + [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), # (2, 3, 2, 2) one-hot + "y": torch.tensor( + [ + [[[0.0, 0.0], [0.0, 1.0]]], + [[[0.0, 0.0], [0.0, 1.0]]], + ] + ), # (2, 1, 2, 2) class indices + "include_background": False, + }, + # batch 0: class 1 y_gt=[[0,0],[0,1]], y_pred=[[0,0],[1,1]] -> dice=2/3 + # class 2 y_gt=[[0,0],[0,0]], y_pred=[[1,0],[0,0]] -> dice=nan + # batch 1: class 1 y_gt=[[0,0],[0,1]], y_pred=[[1,0],[0,0]] -> dice=0.0 + # class 2 y_gt=[[0,0],[0,0]], y_pred=[[0,1],[1,0]] -> dice=nan + [[False, True], [False, True]], # nan pattern +] + + class TestComputeMeanDice(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12, TEST_CASE_MIXED_1]) def test_value(self, input_data, expected_value): result = compute_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) np.testing.assert_equal(result.device, input_data["y_pred"].device) - @parameterized.expand([TEST_CASE_3]) + @parameterized.expand([TEST_CASE_3, TEST_CASE_MIXED_2, TEST_CASE_MIXED_3]) def test_nans(self, input_data, expected_value): result = compute_dice(**input_data) self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value)) From b333238189705f99c8b878b647c9252b2ff456a8 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 17:47:54 +0000 Subject: [PATCH 4/5] apply coderabbit suggestion Signed-off-by: Soumya Snigdha Kundu --- monai/metrics/meandice.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 30667482c1..23e394ca19 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -315,17 +315,17 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl else: y_pred_bool = y_pred.bool() - # Convert y to boolean (independent of y_pred format) + # Convert y: single-channel class indices → one-hot bool; multi-channel → preserve raw values if y.shape[1] == 1 and n_pred_ch > 1: - y_bool = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.bool, device=device) + y_expanded = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.float32, device=device) for c in range(n_pred_ch): - y_bool[:, c] = (y[:, 0] == c) + y_expanded[:, c] = (y[:, 0] == c).float() else: - y_bool = y.bool() + y_expanded = y # Flatten spatial dimensions for vectorized computation: (batch, channels, -1) y_pred_flat = y_pred_bool.reshape(batch_size, n_pred_ch, -1).float() - y_flat = y_bool.reshape(batch_size, n_pred_ch, -1).float() + y_flat = y_expanded.reshape(batch_size, n_pred_ch, -1).float() # Compute Dice per (batch, channel) vectorized: all reductions at once intersection = torch.sum(y_pred_flat * y_flat, dim=-1) # (batch, n_pred_ch) From 260d644e0c9c2778d6a4b5dc1d5ed5be0f38f881 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 17:51:22 +0000 Subject: [PATCH 5/5] add lint fixes Signed-off-by: Soumya Snigdha Kundu --- monai/metrics/meandice.py | 11 +++++++---- tests/metrics/test_compute_meandice.py | 7 +------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py index 23e394ca19..749c951532 100644 --- a/monai/metrics/meandice.py +++ b/monai/metrics/meandice.py @@ -311,7 +311,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl if y_pred.shape[1] == 1 and n_pred_ch > 1: y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device) for c in range(n_pred_ch): - y_pred_bool[:, c] = (y_pred[:, 0] == c) + y_pred_bool[:, c] = y_pred[:, 0] == c else: y_pred_bool = y_pred.bool() @@ -345,9 +345,12 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl empty_mask = y_sum == 0 dice = torch.where( empty_mask, - torch.where(pred_sum == 0, torch.tensor(1.0, device=device, dtype=dice.dtype), - torch.tensor(0.0, device=device, dtype=dice.dtype)), - dice + torch.where( + pred_sum == 0, + torch.tensor(1.0, device=device, dtype=dice.dtype), + torch.tensor(0.0, device=device, dtype=dice.dtype), + ), + dice, ) # Select channels: exclude background if requested diff --git a/tests/metrics/test_compute_meandice.py b/tests/metrics/test_compute_meandice.py index 151d9e231f..5796a1a59a 100644 --- a/tests/metrics/test_compute_meandice.py +++ b/tests/metrics/test_compute_meandice.py @@ -291,12 +291,7 @@ [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]], ] ), # (2, 3, 2, 2) one-hot - "y": torch.tensor( - [ - [[[0.0, 0.0], [0.0, 1.0]]], - [[[0.0, 0.0], [0.0, 1.0]]], - ] - ), # (2, 1, 2, 2) class indices + "y": torch.tensor([[[[0.0, 0.0], [0.0, 1.0]]], [[[0.0, 0.0], [0.0, 1.0]]]]), # (2, 1, 2, 2) class indices "include_background": False, }, # batch 0: class 1 y_gt=[[0,0],[0,1]], y_pred=[[0,0],[1,1]] -> dice=2/3