-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Add parameter to DiceMetric and DiceHelper classes #8774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
b5107da
ccca77a
c110e2a
41e52c1
34a6817
8d412a1
cb433a8
ba2e0b3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,13 +11,19 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from monai.metrics.utils import do_metric_reduction | ||
| from monai.utils import MetricReduction, deprecated_arg | ||
| from monai.utils.module import optional_import | ||
|
|
||
| from .metric import CumulativeIterationMetric | ||
|
|
||
| distance_transform_edt, has_ndimage = optional_import("scipy.ndimage", name="distance_transform_edt") | ||
| generate_binary_structure, _ = optional_import("scipy.ndimage", name="generate_binary_structure") | ||
| sn_label, _ = optional_import("scipy.ndimage", name="label") | ||
|
|
||
| __all__ = ["DiceMetric", "compute_dice", "DiceHelper"] | ||
|
|
||
|
|
||
|
|
@@ -95,6 +101,9 @@ class DiceMetric(CumulativeIterationMetric): | |
| If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True, | ||
| the index begins at "0", otherwise at "1". It can also take a list of label names. | ||
| The outcome will then be returned as a dictionary. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| """ | ||
|
|
||
|
|
@@ -106,6 +115,7 @@ def __init__( | |
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| return_with_label: bool | list[str] = False, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.include_background = include_background | ||
|
|
@@ -114,13 +124,15 @@ def __init__( | |
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.return_with_label = return_with_label | ||
| self.per_component = per_component | ||
| self.dice_helper = DiceHelper( | ||
| include_background=self.include_background, | ||
| reduction=MetricReduction.NONE, | ||
| get_not_nans=False, | ||
| apply_argmax=False, | ||
| ignore_empty=self.ignore_empty, | ||
| num_classes=self.num_classes, | ||
| per_component=self.per_component, | ||
| ) | ||
|
|
||
| def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] | ||
|
|
@@ -175,6 +187,7 @@ def compute_dice( | |
| include_background: bool = True, | ||
| ignore_empty: bool = True, | ||
| num_classes: int | None = None, | ||
| per_component: bool = False, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Computes Dice score metric for a batch of predictions. This performs the same computation as | ||
|
|
@@ -192,6 +205,9 @@ def compute_dice( | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
|
|
||
| Returns: | ||
| Dice scores per batch and per class, (shape: [batch_size, num_classes]). | ||
|
|
@@ -204,6 +220,7 @@ def compute_dice( | |
| apply_argmax=False, | ||
| ignore_empty=ignore_empty, | ||
| num_classes=num_classes, | ||
| per_component=per_component, | ||
| )(y_pred=y_pred, y=y) | ||
|
|
||
|
|
||
|
|
@@ -246,6 +263,9 @@ class DiceHelper: | |
| num_classes: number of input channels (always including the background). When this is ``None``, | ||
| ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are | ||
| single-channel class indices and the number of classes is not automatically inferred from data. | ||
| per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be | ||
| computed for each connected component in the ground truth, and then averaged. This requires 5D binary | ||
| segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation. | ||
| """ | ||
|
|
||
| @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax") | ||
|
|
@@ -262,6 +282,7 @@ def __init__( | |
| num_classes: int | None = None, | ||
| sigmoid: bool | None = None, | ||
| softmax: bool | None = None, | ||
| per_component: bool = False, | ||
| ) -> None: | ||
| # handling deprecated arguments | ||
| if sigmoid is not None: | ||
|
|
@@ -277,6 +298,81 @@ def __init__( | |
| self.activate = activate | ||
| self.ignore_empty = ignore_empty | ||
| self.num_classes = num_classes | ||
| self.per_component = per_component | ||
|
|
||
| def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): | ||
| """ | ||
| Voronoi assignment to connected components (CPU, single EDT) without cc3d. | ||
| Returns the ID of the nearest component for each voxel. | ||
|
|
||
| Args: | ||
| labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds. | ||
| connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26. | ||
| sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances. | ||
|
|
||
| Returns: | ||
| torch.Tensor: Voronoi region IDs (int32) on CPU. | ||
| """ | ||
| if not has_ndimage: | ||
| raise RuntimeError("scipy.ndimage is required for per_component Dice computation.") | ||
| x = np.asarray(labels) | ||
| conn_rank = {6: 1, 18: 2, 26: 3}.get(connectivity, 3) | ||
| structure = generate_binary_structure(rank=3, connectivity=conn_rank) | ||
| cc, num = sn_label(x > 0, structure=structure) | ||
| if num == 0: | ||
| return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32) | ||
| edt_input = np.ones(cc.shape, dtype=np.uint8) | ||
| edt_input[cc > 0] = 0 | ||
| indices = distance_transform_edt(edt_input, sampling=sampling, return_distances=False, return_indices=True) | ||
| voronoi = cc[tuple(indices)] | ||
| return torch.from_numpy(voronoi) | ||
|
|
||
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute per-component Dice for a single batch item.

    Each voxel is assigned to its nearest ground-truth connected component; a Dice
    score is computed inside every such region and the scores are averaged.

    Args:
        y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
        y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).

    Returns:
        torch.Tensor: Tensor of shape (1,) holding the mean Dice over connected
        components. For an empty ground truth the value is NaN when
        ``self.ignore_empty`` is set, otherwise 1.0 if the prediction is empty
        as well and 0.0 if it is not.
    """
    if y_pred.ndim == y.ndim:
        y_pred_idx = torch.argmax(y_pred, dim=1)
        y_idx = torch.argmax(y, dim=1)
    else:
        y_pred_idx = y_pred
        y_idx = y

    if y_idx[0].sum() == 0:
        # Empty ground truth: the NaN must reach the caller untouched, otherwise
        # ``ignore_empty`` has no effect on the reduction.
        if self.ignore_empty:
            score = float("nan")
        elif y_pred_idx.sum() == 0:
            score = 1.0
        else:
            score = 0.0
        return torch.tensor([score], dtype=torch.float32, device=y_idx.device)

    # The Voronoi map is computed on CPU; bring it back to the input device so the
    # bit-packing below does not mix devices.
    cc_assignment = self.compute_voronoi_regions_fast(y_idx[0].detach().cpu()).to(y_idx.device)
    uniq, inv = torch.unique(cc_assignment.reshape(-1), return_inverse=True)
    nof_components = uniq.numel()

    # Two-bit confusion code per voxel: 0 = TN, 1 = FP, 2 = FN, 3 = TP. Combined with
    # the component index, one bincount yields a (components, 4) confusion table.
    code = (y_idx.reshape(-1) << 1) | y_pred_idx.reshape(-1)
    idx = (inv << 2) | code
    hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
    fp, fn, tp = hist[:, 1], hist[:, 2], hist[:, 3]
    denom = 2 * tp + fp + fn

    # Regions with an empty denominator score 1.0, so no NaN/inf can appear here.
    dice_scores = torch.where(
        denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
    )
    return dice_scores.mean().reshape(1)
|
Comment on lines
+330
to
+375
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Docstring states wrong shape; `ignore_empty` is not respected.
Fix docstring and respect ignore_empty:
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
- Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
- for each batch item and for each channel of those items.
+ Compute per-component Dice for a single batch item.
Args:
- y_pred: input predictions with shape HW[D].
- y: ground truth with shape HW[D].
+ y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W).
+ y (torch.Tensor): Ground truth with shape (1, 2, D, H, W).
+
+ Returns:
+ torch.Tensor: Mean Dice over connected components.
"""
data = []
if y_pred.ndim == y.ndim:
y_pred_idx = torch.argmax(y_pred, dim=1)
y_idx = torch.argmax(y, dim=1)
else:
y_pred_idx = y_pred
y_idx = y
if y_idx[0].sum() == 0:
- if y_pred_idx.sum() == 0:
- data.append(torch.tensor(1.0, device=y_idx.device))
- else:
- data.append(torch.tensor(0.0, device=y_idx.device))
+ if self.ignore_empty:
+ data.append(torch.tensor(float("nan"), device=y_idx.device))
+ elif y_pred_idx.sum() == 0:
+ data.append(torch.tensor(1.0, device=y_idx.device))
+ else:
+ data.append(torch.tensor(0.0, device=y_idx.device))
else:🤖 Prompt for AI Agents |
||
|
|
||
| def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
|
|
@@ -322,15 +418,24 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl | |
| y_pred = torch.sigmoid(y_pred) | ||
| y_pred = y_pred > 0.5 | ||
|
|
||
| first_ch = 0 if self.include_background else 1 | ||
| if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2): | ||
| raise ValueError( | ||
| f"per_component requires 5D binary segmentation with 2 channels (background + foreground). " | ||
| f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)." | ||
| ) | ||
|
|
||
| first_ch = 0 if self.include_background and not self.per_component else 1 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Line forces `first_ch = 1` whenever `per_component` is True, regardless of `include_background`. 🤖 Prompt for AI Agents |
||
| data = [] | ||
| for b in range(y_pred.shape[0]): | ||
| c_list = [] | ||
| for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]: | ||
| x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool() | ||
| x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c] | ||
| c_list.append(self.compute_channel(x_pred, x)) | ||
| if self.per_component: | ||
| c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))] | ||
| data.append(torch.stack(c_list)) | ||
|
|
||
| data = torch.stack(data, dim=0).contiguous() # type: ignore | ||
|
|
||
| f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docstring incomplete; return always on CPU.
- Missing `Returns:` section and type annotations per coding guidelines.
- The result is always created on CPU via `torch.from_numpy()` regardless of input device.
Docstring fix
def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None): """ Voronoi assignment to connected components (CPU, single EDT) without cc3d. - Returns the ID of the nearest component for each voxel. Args: - labels: input label map as a numpy array, where values > 0 are considered seeds for connected components. - connectivity: 6/18/26 (3D) - sampling: voxel spacing for anisotropic distances (scipy.ndimage.distance_transform_edt) + labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds. + connectivity (int): 6, 18, or 26 for 3D connectivity. Defaults to 26. + sampling (tuple[float, ...] | None): Voxel spacing for anisotropic distances. + + Returns: + torch.Tensor: Voronoi region IDs (int32) on CPU. """📝 Committable suggestion
🤖 Prompt for AI Agents