Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 55 additions & 28 deletions monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,25 +278,6 @@ def __init__(
self.ignore_empty = ignore_empty
self.num_classes = num_classes

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Dice score for a single binary channel of a single batch item.

    Args:
        y_pred: boolean prediction mask with spatial shape HW[D].
        y: ground-truth mask with spatial shape HW[D].

    Returns:
        A scalar tensor: the Dice score, or NaN when the ground truth is
        empty and ``self.ignore_empty`` is set.
    """
    ground_truth_sum = y.sum()
    if ground_truth_sum > 0:
        # Standard Dice: 2*|A∩B| / (|A| + |B|).
        overlap = torch.masked_select(y, y_pred).sum()
        return 2.0 * overlap / (ground_truth_sum + y_pred.sum())
    # Ground truth is empty from here on.
    if self.ignore_empty:
        return torch.tensor(float("nan"), device=ground_truth_sum.device)
    # Both masks empty -> perfect agreement (1.0); prediction non-empty -> 0.0.
    score = 1.0 if (ground_truth_sum + y_pred.sum()) <= 0 else 0.0
    return torch.tensor(score, device=ground_truth_sum.device)

def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Compute the metric for the given prediction and ground truth.
Expand All @@ -322,16 +303,62 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5

# Vectorized computation (replaces nested loops for better performance)
batch_size = y_pred.shape[0]
device = y_pred.device

# Convert y_pred to boolean (handle single-channel class indices vs multi-channel one-hot independently)
if y_pred.shape[1] == 1 and n_pred_ch > 1:
y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device)
for c in range(n_pred_ch):
y_pred_bool[:, c] = y_pred[:, 0] == c
else:
y_pred_bool = y_pred.bool()

# Convert y: single-channel class indices → one-hot bool; multi-channel → preserve raw values
if y.shape[1] == 1 and n_pred_ch > 1:
y_expanded = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.float32, device=device)
for c in range(n_pred_ch):
y_expanded[:, c] = (y[:, 0] == c).float()
else:
y_expanded = y

# Flatten spatial dimensions for vectorized computation: (batch, channels, -1)
y_pred_flat = y_pred_bool.reshape(batch_size, n_pred_ch, -1).float()
y_flat = y_expanded.reshape(batch_size, n_pred_ch, -1).float()

# Compute Dice per (batch, channel) vectorized: all reductions at once
intersection = torch.sum(y_pred_flat * y_flat, dim=-1) # (batch, n_pred_ch)
pred_sum = torch.sum(y_pred_flat, dim=-1) # (batch, n_pred_ch)
y_sum = torch.sum(y_flat, dim=-1) # (batch, n_pred_ch)

# Dice formula: 2 * intersection / (pred_sum + y_sum)
union = pred_sum + y_sum
dice = (2.0 * intersection) / union # (batch, n_pred_ch)

# Handle empty ground truth cases
if self.ignore_empty:
# Set NaN where ground truth is empty
dice = torch.where(y_sum > 0, dice, torch.tensor(float("nan"), device=device, dtype=dice.dtype))
else:
# Set 1.0 if both empty, 0.0 if only pred is non-empty
empty_mask = y_sum == 0
dice = torch.where(
empty_mask,
torch.where(
pred_sum == 0,
torch.tensor(1.0, device=device, dtype=dice.dtype),
torch.tensor(0.0, device=device, dtype=dice.dtype),
),
dice,
)

# Select channels: exclude background if requested
first_ch = 0 if self.include_background else 1
data = []
for b in range(y_pred.shape[0]):
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
data.append(torch.stack(c_list))
data = torch.stack(data, dim=0).contiguous() # type: ignore
if n_pred_ch > 1:
data = dice[:, first_ch:] # (batch, num_classes_selected)
else:
data = dice # (batch, 1)

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
return (f, not_nans) if self.get_not_nans else f
55 changes: 53 additions & 2 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,66 @@
]


# single-channel y (class indices) with multi-channel y_pred (one-hot)
# Exercises the mixed-format path: y holds class indices (values 0..2) while
# y_pred is already one-hot, so only y needs expansion before comparison.
TEST_CASE_MIXED_1 = [
    {
        "y_pred": torch.tensor(
            [[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]]]]
        ),  # (1, 3, 2, 2) one-hot
        "y": torch.tensor([[[[0.0, 1.0], [2.0, 1.0]]]]),  # (1, 1, 2, 2) class indices
        "include_background": True,
    },
    # Per-class Dice worked out by hand (2*intersection / (|gt| + |pred|)):
    # class 0: y_gt=[[1,0],[0,0]], y_pred=[[0,1],[0,0]] -> dice=0.0
    # class 1: y_gt=[[0,1],[0,1]], y_pred=[[0,0],[0,1]] -> dice=2/3
    # class 2: y_gt=[[0,0],[1,0]], y_pred=[[1,0],[1,0]] -> dice=2/3
    [[0.0000, 0.6667, 0.6667]],
]

# single-channel y_pred (argmaxed, with num_classes) with multi-channel y (one-hot)
TEST_CASE_MIXED_2 = [
{
"y_pred": torch.tensor([[[[2.0, 2.0], [2.0, 2.0]]]]), # (1, 1, 2, 2) all class 2
"y": torch.tensor(
[[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]
), # (1, 3, 2, 2) one-hot, all background
"include_background": True,
"num_classes": 3,
},
# class 0: y_gt=[1,1,1,1](4), y_pred=[0,0,0,0](0) -> dice=0.0
# class 1: y_gt=[0,0,0,0](0), y_pred=[0,0,0,0](0) -> dice=nan (ignore_empty default)
# class 2: y_gt=[0,0,0,0](0), y_pred=[1,1,1,1](4) -> dice=nan (ignore_empty default)
[[False, True, True]], # False=not-nan, True=nan
]

# single-channel y (class indices) with multi-channel y_pred, exclude background
TEST_CASE_MIXED_3 = [
{
"y_pred": torch.tensor(
[
[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
), # (2, 3, 2, 2) one-hot
"y": torch.tensor([[[[0.0, 0.0], [0.0, 1.0]]], [[[0.0, 0.0], [0.0, 1.0]]]]), # (2, 1, 2, 2) class indices
"include_background": False,
},
# batch 0: class 1 y_gt=[[0,0],[0,1]], y_pred=[[0,0],[1,1]] -> dice=2/3
# class 2 y_gt=[[0,0],[0,0]], y_pred=[[1,0],[0,0]] -> dice=nan
# batch 1: class 1 y_gt=[[0,0],[0,1]], y_pred=[[1,0],[0,0]] -> dice=0.0
# class 2 y_gt=[[0,0],[0,0]], y_pred=[[0,1],[1,0]] -> dice=nan
[[False, True], [False, True]], # nan pattern
]


class TestComputeMeanDice(unittest.TestCase):

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12, TEST_CASE_MIXED_1])
def test_value(self, input_data, expected_value):
    """Check compute_dice output values against hand-computed expectations (atol=1e-4)."""
    result = compute_dice(**input_data)
    np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
    # The result must stay on the same device as the input prediction tensor.
    np.testing.assert_equal(result.device, input_data["y_pred"].device)

@parameterized.expand([TEST_CASE_3])
@parameterized.expand([TEST_CASE_3, TEST_CASE_MIXED_2, TEST_CASE_MIXED_3])
def test_nans(self, input_data, expected_value):
result = compute_dice(**input_data)
self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))
Expand Down
Loading