diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index ddd3a350d5..a2bfe21c93 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -155,7 +155,10 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): if not _use_compiled: # pytorch native grid_sample for i, dim in enumerate(grid.shape[1:-1]): - grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1 + # guard against a singleton spatial dim (e.g. a single-slice volume), where + # ``dim - 1 == 0`` would divide by zero; clamp the denominator to 1 so the lone + # voxel maps to -1, matching ``monai.networks.utils.normalize_transform``. + grid[..., i] = grid[..., i] * 2 / max(dim - 1, 1) - 1 index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1)) grid = grid[..., index_ordering] # z, y, x -> x, y, z return F.grid_sample( diff --git a/tests/networks/blocks/warp/test_warp.py b/tests/networks/blocks/warp/test_warp.py index 93af559790..6ee4230783 100644 --- a/tests/networks/blocks/warp/test_warp.py +++ b/tests/networks/blocks/warp/test_warp.py @@ -12,6 +12,7 @@ import unittest from pathlib import Path +from unittest import mock import numpy as np import torch @@ -138,6 +139,28 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, ""): warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3)) + @mock.patch("monai.networks.blocks.warp.USE_COMPILED", False) + def test_singleton_spatial_dim(self): + """ + Regression test for a singleton spatial dimension (a single-slice volume or a + single-row/column image), where the grid normalization ``* 2 / (dim - 1)`` previously + divided by zero. + + The native ``grid_sample`` path is forced via ``USE_COMPILED=False`` because only that + branch normalizes the grid; the csrc ``grid_pull`` path is unaffected. ``padding_mode`` + is ``"zeros"`` so an out-of-range (pre-fix ``nan``) coordinate maps to 0 and exposes the + bug, whereas ``"border"``/``"reflection"`` would clamp onto the lone voxel and mask it. + For a zero displacement field the warped output must contain no ``nan`` and must equal + the input image. + """ + for shape, ndim in [((1, 1, 1, 4, 4), 3), ((1, 1, 1, 5), 2), ((1, 1, 5, 1), 2)]: + image = torch.rand(*shape) + ddf = torch.zeros(shape[0], ndim, *shape[2:]) + warp_layer = Warp(mode="bilinear", padding_mode="zeros") + result = warp_layer(image, ddf) + self.assertFalse(torch.isnan(result).any(), f"NaN in warp output for shape {shape}") + np.testing.assert_allclose(result.cpu().numpy(), image.cpu().numpy(), rtol=1e-4, atol=1e-4) + def test_grad(self): for b in GridSampleMode: for p in GridSamplePadMode: