diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 2265dd3a3f..f55f92db1b 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .absolute_volume_difference import AbsoluteVolumeDifferenceMetric, compute_absolute_volume_difference from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score from .average_precision import AveragePrecisionMetric, compute_average_precision from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning diff --git a/monai/metrics/absolute_volume_difference.py b/monai/metrics/absolute_volume_difference.py new file mode 100644 index 0000000000..3ea9937005 --- /dev/null +++ b/monai/metrics/absolute_volume_difference.py @@ -0,0 +1,183 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.utils import MetricReduction + +from .metric import CumulativeIterationMetric + +__all__ = ["AbsoluteVolumeDifferenceMetric", "compute_absolute_volume_difference"] + + +class AbsoluteVolumeDifferenceMetric(CumulativeIterationMetric): + """ + Compute the Absolute Volume Difference (AVD) between predicted and ground-truth + segmentation masks. + + AVD measures the absolute difference in the number of foreground voxels between + prediction and ground truth, per class. It is particularly useful for small-object + segmentation (e.g. retinal fluid in OCT volumes) where Dice score is known to be + overly sensitive to volume size and does not directly reflect volume discrepancies. + + Reference: + Bogunovic et al. (2019). RETOUCH: The Retinal OCT Fluid Detection and + Segmentation Benchmark and Challenge. + IEEE Transactions on Medical Imaging, 38(8), 1858-1874. + https://ieeexplore.ieee.org/document/8653407 + + The inputs ``y_pred`` and ``y`` are expected to be binarized one-hot tensors with + shape BCHW[D]. If they contain continuous values (e.g. sigmoid outputs), binarize + them first with a suitable threshold transform. + + The typical execution steps of this metric class follow + :py:class:`monai.metrics.metric.Cumulative`. + + Example: + + .. code-block:: python + + import torch + from monai.metrics import AbsoluteVolumeDifferenceMetric + + batch_size, n_classes = 4, 3 + y_pred = torch.randint(0, 2, (batch_size, n_classes, 64, 64, 32)).float() + y = torch.randint(0, 2, (batch_size, n_classes, 64, 64, 32)).float() + + metric = AbsoluteVolumeDifferenceMetric(include_background=False) + metric(y_pred, y) # accumulate + result = metric.aggregate() # shape: (n_classes - 1,) after mean reduction + metric.reset() + + Args: + include_background: whether to include AVD computation on the first channel + (index 0), which is by convention assumed to be background. Defaults to + ``True``. Set to ``False`` when the background class dominates and you only + care about foreground classes (e.g. fluid sub-types in OCT). + reduction: defines how to aggregate per-batch-per-class results. Available + modes are enumerated in :py:class:`monai.utils.enums.MetricReduction`. + Defaults to ``"mean"``. + get_not_nans: if ``True``, :meth:`aggregate` returns ``(metric, not_nans)`` + where ``not_nans`` counts the number of valid (non-NaN) values. + Defaults to ``False``. + ignore_empty: if ``True``, cases where the ground-truth channel is entirely + empty (zero voxels) are excluded from aggregation by setting their value + to ``NaN``. If ``False``, the raw absolute difference (equal to the + predicted volume for that class) is returned. Defaults to ``True``. + """ + + def __init__( + self, + include_background: bool = True, + reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + ignore_empty: bool = True, + ) -> None: + """Initialize AbsoluteVolumeDifferenceMetric. See class docstring for argument descriptions.""" + super().__init__() + self.include_background = include_background + self.reduction = reduction + self.get_not_nans = get_not_nans + self.ignore_empty = ignore_empty + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] + """ + Args: + y_pred: binarized prediction tensor, shape BCHW[D]. + y: binarized ground-truth tensor, shape BCHW[D]. + + Raises: + ValueError: when ``y_pred`` has fewer than three dimensions. + """ + if y_pred.ndimension() < 3: + raise ValueError( + f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {y_pred.ndimension()}." + ) + return compute_absolute_volume_difference( + y_pred=y_pred, + y=y, + include_background=self.include_background, + ignore_empty=self.ignore_empty, + ) + + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Execute reduction logic for the accumulated AVD values. + + Args: + reduction: optional override for the reduction mode set at construction. + """ + data = self.get_buffer() + if not isinstance(data, torch.Tensor): + raise ValueError("the data to aggregate must be a PyTorch Tensor.") + + f, not_nans = do_metric_reduction(data, reduction or self.reduction) + return (f, not_nans) if self.get_not_nans else f + + +def compute_absolute_volume_difference( + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + ignore_empty: bool = True, +) -> torch.Tensor: + """ + Compute the Absolute Volume Difference (AVD) for a batch of segmentation predictions. + + AVD is defined per class as:: + + AVD_c = | sum_{spatial}(y_pred_c) - sum_{spatial}(y_c) | + + where the sum counts the number of foreground voxels in each channel. + + Args: + y_pred: binarized prediction tensor with shape BCHW[D]. + y: binarized ground-truth tensor with shape BCHW[D]. + include_background: whether to include the first channel (background). + Defaults to ``True``. + ignore_empty: if ``True``, entries where the ground-truth channel contains no + foreground voxels are set to ``NaN`` so they are excluded during reduction. + Defaults to ``True``. + + Returns: + AVD per batch item and per class, shape ``[batch_size, num_classes]``. + + Raises: + ValueError: when ``y_pred`` and ``y`` have different shapes. + """ + if y_pred.ndim < 3: + raise ValueError( + f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {y_pred.ndim}." + ) + + if not include_background: + y_pred, y = ignore_background(y_pred=y_pred, y=y) + + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y should have the same shape, got {y_pred.shape} and {y.shape}.") + + # sum over all spatial dimensions; keep batch (dim 0) and channel (dim 1) + reduce_axis = list(range(2, y_pred.ndim)) + vol_pred = torch.sum(y_pred, dim=reduce_axis) # [B, C] + vol_true = torch.sum(y, dim=reduce_axis) # [B, C] + + avd = torch.abs(vol_pred - vol_true) # [B, C] + + if ignore_empty: + # mark cases with no ground-truth foreground as NaN + avd = torch.where(vol_true > 0, avd, torch.tensor(float("nan"), device=avd.device)) + + return avd diff --git a/tests/metrics/test_absolute_volume_difference.py b/tests/metrics/test_absolute_volume_difference.py new file mode 100644 index 0000000000..2c37839217 --- /dev/null +++ b/tests/metrics/test_absolute_volume_difference.py @@ -0,0 +1,174 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.metrics import AbsoluteVolumeDifferenceMetric, compute_absolute_volume_difference + + +class TestComputeAbsoluteVolumeDifference(unittest.TestCase): + """Tests for the standalone compute_absolute_volume_difference function.""" + + def test_perfect_prediction_returns_zero(self): + """Identical prediction and ground truth should yield AVD of zero for all classes.""" + # identical masks → AVD = 0 for every class + y = torch.zeros(2, 3, 4, 4) + y[:, 1, :2, :2] = 1.0 + y[:, 2, 2:, 2:] = 1.0 + result = compute_absolute_volume_difference(y_pred=y, y=y, ignore_empty=False) + self.assertEqual(result.shape, torch.Size([2, 3])) + self.assertTrue(torch.all(result == 0.0)) + + def test_known_volume_difference(self): + """AVD should equal the absolute difference in foreground voxel counts between prediction and GT.""" + # batch=1, 2 classes (background + foreground), 1D spatial of length 10 + y_pred = torch.zeros(1, 2, 10) + y_true = torch.zeros(1, 2, 10) + y_pred[0, 1, :7] = 1.0 # 7 foreground voxels predicted + y_true[0, 1, :4] = 1.0 # 4 foreground voxels in GT + result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=False) + # channel 0: both all-zeros → AVD = 0 + # channel 1: |7 - 4| = 3 + self.assertAlmostEqual(result[0, 0].item(), 0.0) + self.assertAlmostEqual(result[0, 1].item(), 3.0) + + def test_ignore_background(self): + """Setting include_background=False should strip the first channel and reduce output shape accordingly.""" + y_pred = torch.zeros(2, 3, 8, 8) + y_true = torch.zeros(2, 3, 8, 8) + y_pred[:, 1, :3, :3] = 1.0 + y_true[:, 1, :4, :4] = 1.0 + result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, include_background=False) + # background channel stripped → shape [2, 2] + self.assertEqual(result.shape, torch.Size([2, 2])) + + def test_ignore_empty_sets_nan(self): + """Channels with no ground-truth foreground voxels should be NaN when ignore_empty=True.""" + # channel 1 has no GT voxels → should be NaN when ignore_empty=True + y_pred = torch.zeros(1, 2, 6) + y_true = torch.zeros(1, 2, 6) + y_pred[0, 0, :3] = 1.0 + result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=True) + # channel 0: GT is empty → NaN + self.assertTrue(torch.isnan(result[0, 0])) + # channel 1: GT is empty → NaN + self.assertTrue(torch.isnan(result[0, 1])) + + def test_ignore_empty_false_returns_pred_volume(self): + """With ignore_empty=False and empty GT, AVD should equal the predicted volume.""" + # when GT is all zero and ignore_empty=False, AVD = |V_pred - 0| = V_pred + y_pred = torch.zeros(1, 2, 6) + y_true = torch.zeros(1, 2, 6) + y_pred[0, 1, :5] = 1.0 + result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=False) + self.assertAlmostEqual(result[0, 1].item(), 5.0) + + def test_shape_mismatch_raises(self): + """Mismatched y_pred and y shapes should raise a ValueError.""" + with self.assertRaises(ValueError): + compute_absolute_volume_difference( + y_pred=torch.zeros(2, 3, 8, 8), + y=torch.zeros(2, 3, 4, 4), + ) + + def test_too_few_dims_raises(self): + """Input tensors with fewer than 3 dimensions should raise a ValueError.""" + with self.assertRaises(ValueError): + compute_absolute_volume_difference( + y_pred=torch.zeros(2, 3), + y=torch.zeros(2, 3), + ) + + def test_3d_volumes(self): + """AVD should correctly count voxel differences in 3-D spatial inputs.""" + # 3-D spatial (D, H, W) + y_pred = torch.zeros(1, 2, 8, 8, 8) + y_true = torch.zeros(1, 2, 8, 8, 8) + y_pred[0, 1, :4, :4, :4] = 1.0 # 64 voxels + y_true[0, 1, :3, :3, :3] = 1.0 # 27 voxels + result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=False) + self.assertAlmostEqual(result[0, 1].item(), 37.0) + + def test_output_shape_multi_class(self): + """Output shape should be [batch_size, num_classes] for multi-class inputs.""" + y = torch.randint(0, 2, (4, 5, 16, 16)).float() + result = compute_absolute_volume_difference(y_pred=y, y=y, ignore_empty=False) + self.assertEqual(result.shape, torch.Size([4, 5])) + + +class TestAbsoluteVolumeDifferenceMetric(unittest.TestCase): + """Tests for the AbsoluteVolumeDifferenceMetric class (cumulative interface).""" + + def test_aggregate_mean(self): + """Mean reduction over accumulated batches should return the correct per-class AVD.""" + y_pred = torch.zeros(2, 2, 8, 8) + y_true = torch.zeros(2, 2, 8, 8) + y_pred[:, 1, :6, :6] = 1.0 # 36 voxels per batch item + y_true[:, 1, :4, :4] = 1.0 # 16 voxels per batch item + metric = AbsoluteVolumeDifferenceMetric(include_background=False, reduction="mean", ignore_empty=False) + metric(y_pred, y_true) + agg = metric.aggregate() + # single foreground channel, AVD = 20 for both batch items → mean = 20 + self.assertAlmostEqual(agg.item(), 20.0) + metric.reset() + + def test_aggregate_returns_not_nans_when_requested(self): + """When get_not_nans=True, aggregate should return a (metric, not_nans) tuple.""" + y_pred = torch.zeros(2, 2, 4, 4) + y_true = torch.zeros(2, 2, 4, 4) + y_pred[:, 1, :2, :2] = 1.0 + y_true[:, 1, :2, :2] = 1.0 + metric = AbsoluteVolumeDifferenceMetric(include_background=False, get_not_nans=True) + metric(y_pred, y_true) + result, not_nans = metric.aggregate() + self.assertIsInstance(result, torch.Tensor) + self.assertIsInstance(not_nans, torch.Tensor) + metric.reset() + + def test_cumulative_accumulation(self): + """Multiple forward calls before aggregate should use all accumulated data correctly.""" + # calling the metric twice and aggregating should use all accumulated data + metric = AbsoluteVolumeDifferenceMetric(include_background=False, reduction="mean", ignore_empty=False) + for _ in range(3): + y_pred = torch.zeros(1, 2, 8) + y_true = torch.zeros(1, 2, 8) + y_pred[0, 1, :6] = 1.0 + y_true[0, 1, :4] = 1.0 + metric(y_pred, y_true) + agg = metric.aggregate() + self.assertAlmostEqual(agg.item(), 2.0) + metric.reset() + + def test_reset_clears_buffer(self): + """Calling reset() should clear the buffer so a subsequent aggregate() raises.""" + metric = AbsoluteVolumeDifferenceMetric(ignore_empty=False) + y = torch.zeros(1, 2, 4) + y[0, 1, :2] = 1.0 + metric(y, y) + metric.reset() + # after reset the buffer should be empty; calling aggregate raises + with self.assertRaises(Exception): + metric.aggregate() + + def test_imported_from_top_level(self): + """AbsoluteVolumeDifferenceMetric should be importable from the monai.metrics top-level namespace.""" + # ensure the class is accessible from monai.metrics top-level + from monai.metrics import AbsoluteVolumeDifferenceMetric as _AVD + + self.assertIs(_AVD, AbsoluteVolumeDifferenceMetric) + + +if __name__ == "__main__": + unittest.main()