Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):

if not _use_compiled: # pytorch native grid_sample
for i, dim in enumerate(grid.shape[1:-1]):
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
# guard against a singleton spatial dim (e.g. a single-slice volume), where
# ``dim - 1 == 0`` would divide by zero; clamp the denominator to 1 so the lone
# voxel maps to -1, matching ``monai.networks.utils.normalize_transform``.
grid[..., i] = grid[..., i] * 2 / max(dim - 1, 1) - 1
index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
grid = grid[..., index_ordering] # z, y, x -> x, y, z
return F.grid_sample(
Expand Down
23 changes: 23 additions & 0 deletions tests/networks/blocks/warp/test_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import unittest
from pathlib import Path
from unittest import mock

import numpy as np
import torch
Expand Down Expand Up @@ -138,6 +139,28 @@ def test_ill_shape(self):
with self.assertRaisesRegex(ValueError, ""):
warp_layer(image=torch.arange(4).reshape((1, 1, 2, 2)).to(dtype=torch.float), ddf=torch.zeros(1, 2, 3, 3))

@mock.patch("monai.networks.blocks.warp.USE_COMPILED", False)
def test_singleton_spatial_dim(self):
"""
Regression test for a singleton spatial dimension (a single-slice volume or a
single-row/column image), where the grid normalization ``* 2 / (dim - 1)`` previously
divided by zero.

The native ``grid_sample`` path is forced via ``USE_COMPILED=False`` because only that
branch normalizes the grid; the csrc ``grid_pull`` path is unaffected. ``padding_mode``
is ``"zeros"`` so an out-of-range (pre-fix ``nan``) coordinate maps to 0 and exposes the
bug, whereas ``"border"``/``"reflection"`` would clamp onto the lone voxel and mask it.
For a zero displacement field the warped output must contain no ``nan`` and must equal
the input image.
"""
for shape, ndim in [((1, 1, 1, 4, 4), 3), ((1, 1, 1, 5), 2), ((1, 1, 5, 1), 2)]:
image = torch.rand(*shape)
ddf = torch.zeros(shape[0], ndim, *shape[2:])
warp_layer = Warp(mode="bilinear", padding_mode="zeros")
result = warp_layer(image, ddf)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
self.assertFalse(torch.isnan(result).any(), f"NaN in warp output for shape {shape}")
np.testing.assert_allclose(result.cpu().numpy(), image.cpu().numpy(), rtol=1e-4, atol=1e-4)

def test_grad(self):
for b in GridSampleMode:
for p in GridSamplePadMode:
Expand Down
Loading