Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

from .absolute_volume_difference import AbsoluteVolumeDifferenceMetric, compute_absolute_volume_difference
from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
from .average_precision import AveragePrecisionMetric, compute_average_precision
from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning
Expand Down
183 changes: 183 additions & 0 deletions monai/metrics/absolute_volume_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch

from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction

from .metric import CumulativeIterationMetric

__all__ = ["AbsoluteVolumeDifferenceMetric", "compute_absolute_volume_difference"]


class AbsoluteVolumeDifferenceMetric(CumulativeIterationMetric):
"""
Compute the Absolute Volume Difference (AVD) between predicted and ground-truth
segmentation masks.

AVD measures the absolute difference in the number of foreground voxels between
prediction and ground truth, per class. It is particularly useful for small-object
segmentation (e.g. retinal fluid in OCT volumes) where Dice score is known to be
overly sensitive to volume size and does not directly reflect volume discrepancies.

Reference:
Bogunovic et al. (2019). RETOUCH: The Retinal OCT Fluid Detection and
Segmentation Benchmark and Challenge.
IEEE Transactions on Medical Imaging, 38(8), 1858-1874.
https://ieeexplore.ieee.org/document/8653407

The inputs ``y_pred`` and ``y`` are expected to be binarized one-hot tensors with
shape BCHW[D]. If they contain continuous values (e.g. sigmoid outputs), binarize
them first with a suitable threshold transform.

The typical execution steps of this metric class follow
:py:class:`monai.metrics.metric.Cumulative`.

Example:

.. code-block:: python

import torch
from monai.metrics import AbsoluteVolumeDifferenceMetric

batch_size, n_classes = 4, 3
y_pred = torch.randint(0, 2, (batch_size, n_classes, 64, 64, 32)).float()
y = torch.randint(0, 2, (batch_size, n_classes, 64, 64, 32)).float()

metric = AbsoluteVolumeDifferenceMetric(include_background=False)
metric(y_pred, y) # accumulate
result = metric.aggregate() # shape: (n_classes - 1,) after mean reduction
metric.reset()

Args:
include_background: whether to include AVD computation on the first channel
(index 0), which is by convention assumed to be background. Defaults to
``True``. Set to ``False`` when the background class dominates and you only
care about foreground classes (e.g. fluid sub-types in OCT).
reduction: defines how to aggregate per-batch-per-class results. Available
modes are enumerated in :py:class:`monai.utils.enums.MetricReduction`.
Defaults to ``"mean"``.
get_not_nans: if ``True``, :meth:`aggregate` returns ``(metric, not_nans)``
where ``not_nans`` counts the number of valid (non-NaN) values.
Defaults to ``False``.
ignore_empty: if ``True``, cases where the ground-truth channel is entirely
empty (zero voxels) are excluded from aggregation by setting their value
to ``NaN``. If ``False``, the raw absolute difference (equal to the
predicted volume for that class) is returned. Defaults to ``True``.
"""

def __init__(
self,
include_background: bool = True,
reduction: MetricReduction | str = MetricReduction.MEAN,
get_not_nans: bool = False,
ignore_empty: bool = True,
) -> None:
"""Initialize AbsoluteVolumeDifferenceMetric. See class docstring for argument descriptions."""
super().__init__()
self.include_background = include_background
self.reduction = reduction
self.get_not_nans = get_not_nans
self.ignore_empty = ignore_empty

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Args:
y_pred: binarized prediction tensor, shape BCHW[D].
y: binarized ground-truth tensor, shape BCHW[D].

Raises:
ValueError: when ``y_pred`` has fewer than three dimensions.
"""
if y_pred.ndimension() < 3:
raise ValueError(
f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {y_pred.ndimension()}."
)
return compute_absolute_volume_difference(
y_pred=y_pred,
y=y,
include_background=self.include_background,
ignore_empty=self.ignore_empty,
)

def aggregate(
self, reduction: MetricReduction | str | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Execute reduction logic for the accumulated AVD values.

Args:
reduction: optional override for the reduction mode set at construction.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("the data to aggregate must be a PyTorch Tensor.")

f, not_nans = do_metric_reduction(data, reduction or self.reduction)
return (f, not_nans) if self.get_not_nans else f


def compute_absolute_volume_difference(
y_pred: torch.Tensor,
y: torch.Tensor,
include_background: bool = True,
ignore_empty: bool = True,
) -> torch.Tensor:
"""
Compute the Absolute Volume Difference (AVD) for a batch of segmentation predictions.

AVD is defined per class as::

AVD_c = | sum_{spatial}(y_pred_c) - sum_{spatial}(y_c) |

where the sum counts the number of foreground voxels in each channel.

Args:
y_pred: binarized prediction tensor with shape BCHW[D].
y: binarized ground-truth tensor with shape BCHW[D].
include_background: whether to include the first channel (background).
Defaults to ``True``.
ignore_empty: if ``True``, entries where the ground-truth channel contains no
foreground voxels are set to ``NaN`` so they are excluded during reduction.
Defaults to ``True``.

Returns:
AVD per batch item and per class, shape ``[batch_size, num_classes]``.

Raises:
ValueError: when ``y_pred`` and ``y`` have different shapes.
"""
if y_pred.ndim < 3:
raise ValueError(
f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {y_pred.ndim}."
)

if not include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)

if y_pred.shape != y.shape:
raise ValueError(f"y_pred and y should have the same shape, got {y_pred.shape} and {y.shape}.")

# sum over all spatial dimensions; keep batch (dim 0) and channel (dim 1)
reduce_axis = list(range(2, y_pred.ndim))
vol_pred = torch.sum(y_pred, dim=reduce_axis) # [B, C]
vol_true = torch.sum(y, dim=reduce_axis) # [B, C]

avd = torch.abs(vol_pred - vol_true) # [B, C]

if ignore_empty:
# mark cases with no ground-truth foreground as NaN
avd = torch.where(vol_true > 0, avd, torch.tensor(float("nan"), device=avd.device))

return avd
174 changes: 174 additions & 0 deletions tests/metrics/test_absolute_volume_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch

from monai.metrics import AbsoluteVolumeDifferenceMetric, compute_absolute_volume_difference


class TestComputeAbsoluteVolumeDifference(unittest.TestCase):
"""Tests for the standalone compute_absolute_volume_difference function."""

def test_perfect_prediction_returns_zero(self):
"""Identical prediction and ground truth should yield AVD of zero for all classes."""
# identical masks → AVD = 0 for every class
y = torch.zeros(2, 3, 4, 4)
y[:, 1, :2, :2] = 1.0
y[:, 2, 2:, 2:] = 1.0
result = compute_absolute_volume_difference(y_pred=y, y=y, ignore_empty=False)
self.assertEqual(result.shape, torch.Size([2, 3]))
self.assertTrue(torch.all(result == 0.0))

def test_known_volume_difference(self):
"""AVD should equal the absolute difference in foreground voxel counts between prediction and GT."""
# batch=1, 2 classes (background + foreground), 1D spatial of length 10
y_pred = torch.zeros(1, 2, 10)
y_true = torch.zeros(1, 2, 10)
y_pred[0, 1, :7] = 1.0 # 7 foreground voxels predicted
y_true[0, 1, :4] = 1.0 # 4 foreground voxels in GT
result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=False)
# channel 0: both all-zeros → AVD = 0
# channel 1: |7 - 4| = 3
self.assertAlmostEqual(result[0, 0].item(), 0.0)
self.assertAlmostEqual(result[0, 1].item(), 3.0)

def test_ignore_background(self):
"""Setting include_background=False should strip the first channel and reduce output shape accordingly."""
y_pred = torch.zeros(2, 3, 8, 8)
y_true = torch.zeros(2, 3, 8, 8)
y_pred[:, 1, :3, :3] = 1.0
y_true[:, 1, :4, :4] = 1.0
result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, include_background=False)
# background channel stripped → shape [2, 2]
self.assertEqual(result.shape, torch.Size([2, 2]))

def test_ignore_empty_sets_nan(self):
"""Channels with no ground-truth foreground voxels should be NaN when ignore_empty=True."""
# channel 1 has no GT voxels → should be NaN when ignore_empty=True
y_pred = torch.zeros(1, 2, 6)
y_true = torch.zeros(1, 2, 6)
y_pred[0, 0, :3] = 1.0
result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=True)
# channel 0: GT is empty → NaN
self.assertTrue(torch.isnan(result[0, 0]))
# channel 1: GT is empty → NaN
self.assertTrue(torch.isnan(result[0, 1]))

def test_ignore_empty_false_returns_pred_volume(self):
"""With ignore_empty=False and empty GT, AVD should equal the predicted volume."""
# when GT is all zero and ignore_empty=False, AVD = |V_pred - 0| = V_pred
y_pred = torch.zeros(1, 2, 6)
y_true = torch.zeros(1, 2, 6)
y_pred[0, 1, :5] = 1.0
result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=False)
self.assertAlmostEqual(result[0, 1].item(), 5.0)

def test_shape_mismatch_raises(self):
"""Mismatched y_pred and y shapes should raise a ValueError."""
with self.assertRaises(ValueError):
compute_absolute_volume_difference(
y_pred=torch.zeros(2, 3, 8, 8),
y=torch.zeros(2, 3, 4, 4),
)

def test_too_few_dims_raises(self):
"""Input tensors with fewer than 3 dimensions should raise a ValueError."""
with self.assertRaises(ValueError):
compute_absolute_volume_difference(
y_pred=torch.zeros(2, 3),
y=torch.zeros(2, 3),
)

def test_3d_volumes(self):
"""AVD should correctly count voxel differences in 3-D spatial inputs."""
# 3-D spatial (D, H, W)
y_pred = torch.zeros(1, 2, 8, 8, 8)
y_true = torch.zeros(1, 2, 8, 8, 8)
y_pred[0, 1, :4, :4, :4] = 1.0 # 64 voxels
y_true[0, 1, :3, :3, :3] = 1.0 # 27 voxels
result = compute_absolute_volume_difference(y_pred=y_pred, y=y_true, ignore_empty=False)
self.assertAlmostEqual(result[0, 1].item(), 37.0)

def test_output_shape_multi_class(self):
"""Output shape should be [batch_size, num_classes] for multi-class inputs."""
y = torch.randint(0, 2, (4, 5, 16, 16)).float()
result = compute_absolute_volume_difference(y_pred=y, y=y, ignore_empty=False)
self.assertEqual(result.shape, torch.Size([4, 5]))


class TestAbsoluteVolumeDifferenceMetric(unittest.TestCase):
"""Tests for the AbsoluteVolumeDifferenceMetric class (cumulative interface)."""

def test_aggregate_mean(self):
"""Mean reduction over accumulated batches should return the correct per-class AVD."""
y_pred = torch.zeros(2, 2, 8, 8)
y_true = torch.zeros(2, 2, 8, 8)
y_pred[:, 1, :6, :6] = 1.0 # 36 voxels per batch item
y_true[:, 1, :4, :4] = 1.0 # 16 voxels per batch item
metric = AbsoluteVolumeDifferenceMetric(include_background=False, reduction="mean", ignore_empty=False)
metric(y_pred, y_true)
agg = metric.aggregate()
# single foreground channel, AVD = 20 for both batch items → mean = 20
self.assertAlmostEqual(agg.item(), 20.0)
metric.reset()

def test_aggregate_returns_not_nans_when_requested(self):
"""When get_not_nans=True, aggregate should return a (metric, not_nans) tuple."""
y_pred = torch.zeros(2, 2, 4, 4)
y_true = torch.zeros(2, 2, 4, 4)
y_pred[:, 1, :2, :2] = 1.0
y_true[:, 1, :2, :2] = 1.0
metric = AbsoluteVolumeDifferenceMetric(include_background=False, get_not_nans=True)
metric(y_pred, y_true)
result, not_nans = metric.aggregate()
self.assertIsInstance(result, torch.Tensor)
self.assertIsInstance(not_nans, torch.Tensor)
metric.reset()

def test_cumulative_accumulation(self):
"""Multiple forward calls before aggregate should use all accumulated data correctly."""
# calling the metric twice and aggregating should use all accumulated data
metric = AbsoluteVolumeDifferenceMetric(include_background=False, reduction="mean", ignore_empty=False)
for _ in range(3):
y_pred = torch.zeros(1, 2, 8)
y_true = torch.zeros(1, 2, 8)
y_pred[0, 1, :6] = 1.0
y_true[0, 1, :4] = 1.0
metric(y_pred, y_true)
agg = metric.aggregate()
self.assertAlmostEqual(agg.item(), 2.0)
metric.reset()

def test_reset_clears_buffer(self):
"""Calling reset() should clear the buffer so a subsequent aggregate() raises."""
metric = AbsoluteVolumeDifferenceMetric(ignore_empty=False)
y = torch.zeros(1, 2, 4)
y[0, 1, :2] = 1.0
metric(y, y)
metric.reset()
# after reset the buffer should be empty; calling aggregate raises
with self.assertRaises(Exception):
metric.aggregate()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def test_imported_from_top_level(self):
"""AbsoluteVolumeDifferenceMetric should be importable from the monai.metrics top-level namespace."""
# ensure the class is accessible from monai.metrics top-level
from monai.metrics import AbsoluteVolumeDifferenceMetric as _AVD

self.assertIs(_AVD, AbsoluteVolumeDifferenceMetric)


if __name__ == "__main__":
unittest.main()
Loading