From 74e7ca4233a50733f87264d92cb2a6bd281a6d1b Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Mon, 2 Mar 2026 09:58:05 +0000 Subject: [PATCH 1/8] Add explicit spatial_ndim tracking to MetaTensor (Fixes #6397) Fixes dimension-mismatch crashes when einops.rearrange() or other reshape operations change tensor ndim by decoupling spatial rank from tensor shape. - Add _spatial_ndim attribute to MetaObj, derived from affine in MetaTensor - Expose spatial_ndim property with getter/setter and validation - Sync spatial_ndim on affine assignment and propagate through collation - Update transforms to use spatial_ndim instead of ndim-1 heuristic - Add 18 new tests for spatial_ndim behavior Signed-off-by: Soumya Snigdha Kundu --- monai/data/__init__.py | 2 +- monai/data/meta_obj.py | 1 + monai/data/meta_tensor.py | 54 ++++++-- monai/data/utils.py | 1 + monai/transforms/croppad/functional.py | 4 +- monai/transforms/intensity/array.py | 14 +- monai/transforms/inverse.py | 2 +- monai/transforms/lazy/functional.py | 10 +- monai/transforms/post/array.py | 6 +- monai/transforms/spatial/array.py | 25 ++-- monai/transforms/spatial/functional.py | 4 +- monai/transforms/utility/array.py | 36 ++--- tests/data/meta_tensor/test_meta_tensor.py | 1 + tests/data/meta_tensor/test_spatial_ndim.py | 139 ++++++++++++++++++++ tests/transforms/test_squeezedim.py | 1 + tests/transforms/utility/test_splitdim.py | 34 +++++ 16 files changed, 281 insertions(+), 53 deletions(-) create mode 100644 tests/data/meta_tensor/test_spatial_ndim.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 5e367cc297..fe59e6eb8a 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -71,7 +71,7 @@ monai_to_itk_ddf, ) from .meta_obj import MetaObj, get_track_meta, set_track_meta -from .meta_tensor import MetaTensor +from .meta_tensor import MetaTensor, get_spatial_ndim from .samplers import DistributedSampler, DistributedWeightedRandomSampler from .synthetic import create_test_image_2d, create_test_image_3d from .test_time_augmentation import TestTimeAugmentation diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 15e6e8be15..d903e68d0b 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -84,6 +84,7 @@ def __init__(self) -> None: self._applied_operations: list = MetaObj.get_default_applied_operations() self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops self._is_batch: bool = False + self._spatial_ndim: int = 3 # default: 3 spatial dimensions @staticmethod def flatten_meta_objs(*args: Iterable): diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 12bd76ba60..e6de8d1a1e 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -21,14 +21,25 @@ import torch import monai -from monai.config.type_definitions import NdarrayTensor +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor -__all__ = ["MetaTensor"] +__all__ = ["MetaTensor", "get_spatial_ndim"] + + +def get_spatial_ndim(img: NdarrayOrTensor) -> int: + """Return the number of spatial dimensions assuming channel-first layout. + + Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to + ``img.ndim - 1``. + """ + if isinstance(img, MetaTensor): + return img.spatial_ndim + return img.ndim - 1 @functools.lru_cache(None) @@ -111,6 +122,7 @@ def __new__( meta: dict | None = None, applied_operations: list | None = None, *args, + spatial_ndim: int | None = None, **kwargs, ) -> MetaTensor: _kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {} @@ -123,6 +135,7 @@ def __init__( meta: dict | None = None, applied_operations: list | None = None, *_args, + spatial_ndim: int | None = None, **_kwargs, ) -> None: """ @@ -134,6 +147,8 @@ def __init__( the list is typically maintained by `monai.transforms.TraceableTransform`. See also: :py:class:`monai.transforms.TraceableTransform` _args: additional args (currently not in use in this constructor). + spatial_ndim: optional number of spatial dimensions. If ``None``, derived + from the affine matrix clamped by the tensor shape. _kwargs: additional kwargs (currently not in use in this constructor). Note: @@ -158,6 +173,12 @@ def __init__( self.affine = self.meta[MetaKeys.AFFINE] else: self.affine = self.get_default_affine() + # derive spatial_ndim from affine, clamped by tensor shape + if spatial_ndim is not None: + self.spatial_ndim = spatial_ndim + elif self.affine.ndim == 2: + self.spatial_ndim = min(self.affine.shape[-1] - 1, max(self.ndim - 1, 1)) + # applied_operations if applied_operations is not None: self.applied_operations = applied_operations @@ -468,14 +489,29 @@ def affine(self) -> torch.Tensor: @affine.setter def affine(self, d: NdarrayTensor) -> None: """Set the affine.""" - self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) + a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) + self.meta[MetaKeys.AFFINE] = a + if a.ndim == 2: # non-batched: sync spatial_ndim + self.spatial_ndim = a.shape[-1] - 1 + + @property + def spatial_ndim(self) -> int: + """Get the number of spatial dimensions.""" + return getattr(self, "_spatial_ndim", 3) + + @spatial_ndim.setter + def spatial_ndim(self, val: int) -> None: + """Set the number of spatial dimensions.""" + if val < 1: + raise ValueError(f"spatial_ndim must be >= 1, got {val}") + self._spatial_ndim = val @property def pixdim(self): """Get the spacing""" if self.is_batch: - return [affine_to_spacing(a) for a in self.affine] - return affine_to_spacing(self.affine) + return [affine_to_spacing(a, r=self.spatial_ndim) for a in self.affine] + return affine_to_spacing(self.affine, r=self.spatial_ndim) def peek_pending_shape(self): """ @@ -490,7 +526,7 @@ def peek_pending_shape(self): def peek_pending_affine(self): res = self.affine - r = len(res) - 1 + r = res.shape[-1] - 1 if res.ndim >= 2 else self.spatial_ndim if r not in (2, 3): warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.") for p in self.pending_operations: @@ -503,8 +539,10 @@ def peek_pending_affine(self): return res def peek_pending_rank(self): - a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine - return 1 if a is None else int(max(1, len(a) - 1)) + if self.pending_operations: + a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) + return 1 if a is None else int(max(1, len(a) - 1)) + return self.spatial_ndim def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override] """ diff --git a/monai/data/utils.py b/monai/data/utils.py index 4e5a3bd7f6..35d19e59ee 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -432,6 +432,7 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None): collated.meta = default_collate(meta_dicts) collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] collated.is_batch = True + collated.spatial_ndim = getattr(batch[0], "spatial_ndim", 3) # assumes uniform spatial_ndim return collated diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py index acf42849d3..378f1cf688 100644 --- a/monai/transforms/croppad/functional.py +++ b/monai/transforms/croppad/functional.py @@ -22,7 +22,7 @@ from monai.config.type_definitions import NdarrayTensor from monai.data.meta_obj import get_track_meta -from monai.data.meta_tensor import MetaTensor +from monai.data.meta_tensor import MetaTensor, get_spatial_ndim from monai.data.utils import to_affine_nd from monai.transforms.inverse import TraceableTransform from monai.transforms.utils import convert_pad_mode, create_translate @@ -132,7 +132,7 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int, mode: the padding mode. kwargs: other arguments for the `np.pad` or `torch.pad` function. """ - ndim = len(img.shape) - 1 + ndim = get_spatial_ndim(img) matrix_np = np.round(to_affine_nd(ndim, convert_to_numpy(translation_mat, wrap_sequence=True).copy())) matrix_np = to_affine_nd(len(spatial_size), matrix_np) cc = np.asarray(np.meshgrid(*[[0.5, x - 0.5] for x in spatial_size], indexing="ij")) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 0421d34492..6adce0b49e 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -26,6 +26,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import get_track_meta +from monai.data.meta_tensor import get_spatial_ndim from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap from monai.data.utils import get_random_patch, get_valid_patch_size from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter @@ -1580,7 +1581,7 @@ def __init__(self, radius: Sequence[int] | int = 1) -> None: def __call__(self, img: NdarrayTensor) -> NdarrayTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) - spatial_dims = img_t.ndim - 1 + spatial_dims = get_spatial_ndim(img) r = ensure_tuple_rep(self.radius, spatial_dims) median_filter_instance = MedianFilter(r, spatial_dims=spatial_dims) out_t: torch.Tensor = median_filter_instance(img_t) @@ -1616,7 +1617,7 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor: sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma] else: sigma = torch.as_tensor(self.sigma, device=img_t.device) - gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx) + gaussian_filter = GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx) out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0) out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype) @@ -1673,7 +1674,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if not self._do_transform: return img - sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=img.ndim - 1) + sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=get_spatial_ndim(img)) return GaussianSmooth(sigma=sigma, approx=self.approx)(img) @@ -1723,7 +1724,7 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor: img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) gf1, gf2 = ( - GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device) + GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx).to(img_t.device) for sigma in (self.sigma1, self.sigma2) ) blurred_f = gf1(img_t.unsqueeze(0)) @@ -1811,8 +1812,9 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None: raise RuntimeError("please call the `randomize()` function first.") - sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=img.ndim - 1) - sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=img.ndim - 1) + _sp = get_spatial_ndim(img) + sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=_sp) + sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=_sp) return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index ecf918f47a..fdde1684bb 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -213,7 +213,7 @@ def track_transform_meta( orig_affine = data_t.peek_pending_affine() orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0] try: - affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64) + affine = orig_affine @ to_affine_nd(orig_affine.shape[-1] - 1, affine, dtype=torch.float64) except RuntimeError as e: if orig_affine.ndim > 2: if data_t.is_batch: diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index a33d76807c..d693ba6810 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -256,9 +256,11 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, if not pending: return data, [] + _rank = data.spatial_ndim if isinstance(data, MetaTensor) else 3 + cumulative_xform = affine_from_pending(pending[0]) - if cumulative_xform.shape[0] == 3: - cumulative_xform = to_affine_nd(3, cumulative_xform) + if cumulative_xform.shape[0] < _rank + 1: + cumulative_xform = to_affine_nd(_rank, cumulative_xform) cur_kwargs = kwargs_from_pending(pending[0]) override_kwargs: dict[str, Any] = {} @@ -283,8 +285,8 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None, data = resample(data.to(device), cumulative_xform, _cur_kwargs) next_matrix = affine_from_pending(p) - if next_matrix.shape[0] == 3: - next_matrix = to_affine_nd(3, next_matrix) + if next_matrix.shape[0] < _rank + 1: + next_matrix = to_affine_nd(_rank, next_matrix) cumulative_xform = combine_transforms(cumulative_xform, next_matrix) cur_kwargs.update(new_kwargs) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 47623b748d..ac683f5b3e 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -23,7 +23,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta -from monai.data.meta_tensor import MetaTensor +from monai.data.meta_tensor import MetaTensor, get_spatial_ndim from monai.networks import one_hot from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering from monai.transforms.inverse import InvertibleTransform @@ -624,7 +624,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) img_: torch.Tensor = convert_to_tensor(img, track_meta=False) - spatial_dims = len(img_.shape) - 1 + spatial_dims = get_spatial_ndim(img) img_ = img_.unsqueeze(0) # adds a batch dim if spatial_dims == 2: kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32) @@ -1104,7 +1104,7 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor: image_tensor = convert_to_tensor(image, track_meta=get_track_meta()) # Check/set spatial axes - n_spatial_dims = image_tensor.ndim - 1 # excluding the channel dimension + n_spatial_dims = get_spatial_ndim(image_tensor) valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0)) # Check gradient axes to be valid diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index b6bf211cc4..76b7f7d737 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -27,7 +27,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.box_utils import BoxMode, StandardMode from monai.data.meta_obj import get_track_meta, set_track_meta -from monai.data.meta_tensor import MetaTensor +from monai.data.meta_tensor import MetaTensor, get_spatial_ndim from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij @@ -848,12 +848,14 @@ def __call__( anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma - input_ndim = img.ndim - 1 # spatial ndim + input_ndim = get_spatial_ndim(img) if self.size_mode == "all": output_ndim = len(ensure_tuple(self.spatial_size)) if output_ndim > input_ndim: input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) img = img.reshape(input_shape) + if isinstance(img, MetaTensor): + img.spatial_ndim = output_ndim elif output_ndim < input_ndim: raise ValueError( "len(spatial_size) must be greater or equal to img spatial dimensions, " @@ -1034,7 +1036,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0] if isinstance(out, MetaTensor): affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False) - mat = to_affine_nd(len(affine) - 1, transform_t) + mat = to_affine_nd(out.spatial_ndim, transform_t) out.affine @= convert_to_dst_type(mat, affine)[0] return out @@ -1131,7 +1133,7 @@ def __call__( during initialization for this call. Defaults to None. """ img = convert_to_tensor(img, track_meta=get_track_meta()) - _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim + _zoom = ensure_tuple_rep(self.zoom, get_spatial_ndim(img)) _mode = self.mode if mode is None else mode _padding_mode = padding_mode or self.padding_mode _align_corners = self.align_corners if align_corners is None else align_corners @@ -1519,7 +1521,7 @@ def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) if not self._do_transform: return None - self._axis = self.R.randint(data.ndim - 1) + self._axis = self.R.randint(get_spatial_ndim(data)) def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: """ @@ -1629,13 +1631,14 @@ def randomize(self, img: NdarrayOrTensor) -> None: super().randomize(None) if not self._do_transform: return None + _sp = get_spatial_ndim(img) self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] if len(self._zoom) == 1: # to keep the spatial shape ratio, use same random zoom factor for all dims - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) - elif len(self._zoom) == 2 and img.ndim > 3: + self._zoom = ensure_tuple_rep(self._zoom[0], _sp) + elif len(self._zoom) == 2 and _sp > 2: # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) + self._zoom = ensure_tuple_rep(self._zoom[0], _sp - 1) + ensure_tuple(self._zoom[-1]) def __call__( self, @@ -2350,7 +2353,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: out.meta = data.meta # type: ignore affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] xform, *_ = convert_to_dst_type( - Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine + Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine ) out.affine @= xform return out @@ -2619,7 +2622,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: out.meta = data.meta # type: ignore affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] xform, *_ = convert_to_dst_type( - Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine + Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine ) out.affine @= xform return out @@ -3032,7 +3035,7 @@ def __call__( raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`") all_ranges = [] - num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1) + num_cells = ensure_tuple_rep(self.num_cells, get_spatial_ndim(img)) if isinstance(img, MetaTensor) and img.pending_operations: warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.") for dim_idx, dim_size in enumerate(img.shape[1:]): diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 3001dd1e64..a403019751 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -26,7 +26,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.data.box_utils import get_boxmode from monai.data.meta_obj import get_track_meta -from monai.data.meta_tensor import MetaTensor +from monai.data.meta_tensor import MetaTensor, get_spatial_ndim from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.transforms.croppad.array import ResizeWithPadOrCrop @@ -99,7 +99,7 @@ def spatial_resample( src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4) img = convert_to_tensor(data=img, track_meta=get_track_meta()) # ensure spatial rank is <= 3 - spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) + spatial_rank = min(get_spatial_ndim(img), 3) if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None: spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size src_affine = to_affine_nd(spatial_rank, src_affine).to(torch.float64) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3dc7897feb..bd86120350 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -30,7 +30,7 @@ from monai.config import DtypeLike from monai.config.type_definitions import NdarrayOrTensor from monai.data.meta_obj import get_track_meta -from monai.data.meta_tensor import MetaTensor +from monai.data.meta_tensor import MetaTensor, get_spatial_ndim from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps from monai.networks.layers.simplelayers import ( ApplyFilter, @@ -314,23 +314,28 @@ def __call__(self, img: torch.Tensor) -> list[torch.Tensor]: """ Apply the transform to `img`. """ - n_out = img.shape[self.dim] + dim = self.dim if self.dim >= 0 else self.dim + img.ndim + n_out = img.shape[dim] if isinstance(img, torch.Tensor): - outputs = list(torch.split(img, 1, self.dim)) + outputs = list(torch.split(img, 1, dim)) else: - outputs = np.split(img, n_out, self.dim) + outputs = np.split(img, n_out, dim) for idx, item in enumerate(outputs): if not self.keepdim: - outputs[idx] = item.squeeze(self.dim) + outputs[idx] = item.squeeze(dim) if self.update_meta and isinstance(img, MetaTensor): - if not isinstance(item, MetaTensor): - item = MetaTensor(item, meta=img.meta) - if self.dim == 0: # don't update affine if channel dim + out = outputs[idx] + if not isinstance(out, MetaTensor): + out = MetaTensor(out, meta=img.meta) + outputs[idx] = out + if dim == 0: # don't update affine if channel dim continue - ndim = len(item.affine) - shift = torch.eye(ndim, device=item.affine.device, dtype=item.affine.dtype) - shift[self.dim - 1, -1] = idx - item.affine = item.affine @ shift + ndim = len(out.affine) + shift = torch.eye(ndim, device=out.affine.device, dtype=out.affine.dtype) + shift[dim - 1, -1] = idx + out.affine = out.affine @ shift + if not self.keepdim: + out.spatial_ndim = max(1, out.spatial_ndim - 1) return outputs @@ -1506,8 +1511,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: Args: img: data to be transformed, assuming `img` is channel first. """ - if max(self.spatial_dims) > img.ndim - 2 or min(self.spatial_dims) < 0: - raise ValueError(f"`spatial_dims` values must be within [0, {img.ndim - 2}]") + _sp = get_spatial_ndim(img) + if max(self.spatial_dims) > _sp - 1 or min(self.spatial_dims) < 0: + raise ValueError(f"`spatial_dims` values must be within [0, {_sp - 1}]") spatial_size = img.shape[1:] coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_size), indexing="ij")) @@ -1675,7 +1681,7 @@ def __call__( applied_operations = img.applied_operations img_, prev_type, device = convert_data_type(img, torch.Tensor) - ndim = img_.ndim - 1 # assumes channel first format + ndim = get_spatial_ndim(img) if isinstance(self.filter, str): self.filter = self._get_filter_from_string(self.filter, self.filter_size, ndim) # type: ignore diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index 427902f784..e9fb834d62 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -68,6 +68,7 @@ def check_ids(self, a, b, should_match): def check_meta(self, a: MetaTensor, b: MetaTensor) -> None: self.assertEqual(a.is_batch, b.is_batch) + self.assertEqual(a.spatial_ndim, b.spatial_ndim) meta_a, meta_b = a.meta, b.meta # need to split affine from rest of metadata aff_a = meta_a.get("affine", None) diff --git a/tests/data/meta_tensor/test_spatial_ndim.py b/tests/data/meta_tensor/test_spatial_ndim.py new file mode 100644 index 0000000000..876d5b2f1e --- /dev/null +++ b/tests/data/meta_tensor/test_spatial_ndim.py @@ -0,0 +1,139 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from copy import deepcopy +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor +from monai.data.utils import collate_meta_tensor_fn, decollate_batch +from monai.transforms import Resize, SqueezeDim +from monai.transforms.utility.array import SplitDim +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +# (shape, affine, expected_spatial_ndim) +CONSTRUCTION_CASES = [ + ((1, 10, 10, 10), None, 3), # default eye(4) + ((1, 10, 10), torch.eye(3), 2), # eye(3) + ((1, 10), torch.eye(2), 1), # eye(2) +] + +# (description, op, expected_spatial_ndim) -- op takes a 2D MetaTensor and returns a new one +PRESERVATION_CASES = [ + ("reshape", lambda t: t.reshape(1, 100), 2), + ("unsqueeze", lambda t: t.unsqueeze(0), 2), + ("squeeze", lambda t: MetaTensor(torch.randn(1, 1, 10, 10), affine=torch.eye(3)).squeeze(1), 2), + ("clone", lambda t: t.clone(), 2), + ("deepcopy", lambda t: deepcopy(t), 2), +] + + +class TestSpatialNdim(unittest.TestCase): + @parameterized.expand(CONSTRUCTION_CASES) + def test_construction(self, shape, affine, expected): + kwargs = {"affine": affine} if affine is not None else {} + t = MetaTensor(torch.randn(*shape), **kwargs) + self.assertEqual(t.spatial_ndim, expected) + + @parameterized.expand(PRESERVATION_CASES) + def test_preserved_through_op(self, _desc, op, expected): + t = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3)) + t2 = op(t) + self.assertEqual(t2.spatial_ndim, expected) + + def test_setter_and_validation(self): + t = MetaTensor(torch.randn(1, 10, 10, 10)) + t.spatial_ndim = 2 + self.assertEqual(t.spatial_ndim, 2) + for bad in (0, -1): + with self.assertRaises(ValueError): + t.spatial_ndim = bad + + def test_affine_setter_syncs(self): + t = MetaTensor(torch.randn(1, 10, 10, 10)) + t.affine = torch.eye(3) + self.assertEqual(t.spatial_ndim, 2) + + def test_copy_from_meta_tensor(self): + t1 = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3)) + self.assertEqual(MetaTensor(t1).spatial_ndim, 2) + + def test_collate_and_decollate(self): + t1 = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3)) + t2 = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3)) + batch = collate_meta_tensor_fn([t1, t2]) + self.assertEqual(batch.spatial_ndim, 2) + for item in decollate_batch(batch): + self.assertIsInstance(item, MetaTensor) + self.assertEqual(item.spatial_ndim, 2) + + def test_derived_properties(self): + """peek_pending_rank, peek_pending_shape, and pixdim all respect spatial_ndim.""" + aff = torch.diag(torch.tensor([2.0, 3.0, 1.0], dtype=torch.float64)) + t = MetaTensor(torch.randn(1, 10, 10), affine=aff) + self.assertEqual(t.peek_pending_rank(), 2) + self.assertEqual(t.peek_pending_shape(), (10, 10)) + self.assertEqual(len(t.pixdim), 2) + + def test_squeeze_dim_transform(self): + t = MetaTensor(torch.randn(1, 10, 1, 10)) + result = SqueezeDim(dim=2)(t) + self.assertEqual(result.spatial_ndim, result.affine.shape[-1] - 1) + + def test_splitdim_channel_dim_no_decrement(self): + t = MetaTensor(torch.randn(3, 8, 7)) + for item in SplitDim(dim=0, keepdim=False)(t): + if isinstance(item, MetaTensor): + self.assertEqual(item.spatial_ndim, 2) + + def test_lazy_apply_pending_2d(self): + """apply_pending uses spatial_ndim for 2D data instead of hardcoded 3.""" + from monai.transforms.lazy.functional import apply_pending + from monai.utils.enums import LazyAttr + + t = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3)) + self.assertEqual(t.spatial_ndim, 2) + # Push a pending 2D affine operation + pending_op = { + LazyAttr.AFFINE: torch.eye(3, dtype=torch.float64), + LazyAttr.SHAPE: (10, 10), + LazyAttr.INTERP_MODE: "bilinear", + LazyAttr.PADDING_MODE: "zeros", + } + t.push_pending_operation(pending_op) + result, applied = apply_pending(t, overrides={"mode": "bilinear"}) + self.assertIsInstance(result, MetaTensor) + self.assertEqual(len(applied), 1) + + @skipUnless(has_einops, "Requires einops") + def test_einops_rearrange_then_resize(self): + """Reproduce the exact #6397 bug: einops.rearrange -> Resize.""" + from einops import rearrange + + x = MetaTensor(torch.randn(1, 1, 64, 64, 3)) + x.is_batch = True + x.meta["affine"] = torch.eye(4)[None] + x_ = rearrange(x, "b c h w d -> (b c) h w d") + self.assertIsInstance(x_, MetaTensor) + self.assertEqual(x_.spatial_ndim, 3) + out = Resize(spatial_size=(32, 32, 3), mode="trilinear", align_corners=True)(x_) + self.assertEqual(out.shape[-3:], (32, 32, 3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/test_squeezedim.py b/tests/transforms/test_squeezedim.py index 5fd333d821..8e838629f4 100644 --- a/tests/transforms/test_squeezedim.py +++ b/tests/transforms/test_squeezedim.py @@ -38,6 +38,7 @@ def test_shape(self, input_param, test_data, expected_shape): self.assertTupleEqual(result.shape, expected_shape) if "dim" in input_param and input_param["dim"] == 2 and isinstance(result, MetaTensor): assert_allclose(result.affine.shape, [3, 3]) + self.assertEqual(result.spatial_ndim, result.affine.shape[-1] - 1) @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): diff --git a/tests/transforms/utility/test_splitdim.py b/tests/transforms/utility/test_splitdim.py index 31d9983a2b..5cf7a08112 100644 --- a/tests/transforms/utility/test_splitdim.py +++ b/tests/transforms/utility/test_splitdim.py @@ -16,6 +16,7 @@ import numpy as np from parameterized import parameterized +from monai.data import MetaTensor from monai.transforms.utility.array import SplitDim from tests.test_utils import TEST_NDARRAYS @@ -47,6 +48,39 @@ def test_singleton(self): out = SplitDim(dim=1)(arr) self.assertEqual(out[0].shape, shape) + def test_spatial_ndim_decremented(self): + """spatial_ndim decremented for keepdim=False on spatial dim.""" + import torch + + arr = MetaTensor(torch.randn(2, 3, 8, 7)) + self.assertEqual(arr.spatial_ndim, 3) + out = SplitDim(dim=1, keepdim=False)(arr) + for item in out: + if isinstance(item, MetaTensor): + self.assertEqual(item.spatial_ndim, 2) + + def test_spatial_ndim_negative_dim(self): + """spatial_ndim decremented for keepdim=False with negative dim.""" + import torch + + arr = MetaTensor(torch.randn(2, 3, 8, 7)) + self.assertEqual(arr.spatial_ndim, 3) + out = SplitDim(dim=-1, keepdim=False)(arr) + for item in out: + if isinstance(item, MetaTensor): + self.assertEqual(item.spatial_ndim, 2) + + def test_spatial_ndim_channel_dim_no_decrement(self): + """spatial_ndim not decremented for keepdim=False on channel dim (dim=0).""" + import torch + + arr = MetaTensor(torch.randn(3, 8, 7)) + self.assertEqual(arr.spatial_ndim, 2) + out = SplitDim(dim=0, keepdim=False)(arr) + for item in out: + if isinstance(item, MetaTensor): + self.assertEqual(item.spatial_ndim, 2) + if __name__ == "__main__": unittest.main() From c52a14914c24c5385b1d48cdae5fab12aeeb591e Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Tue, 3 Mar 2026 19:30:45 +0000 Subject: [PATCH 2/8] address coderabbit Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/spatial/functional.py | 5 +++-- tests/data/meta_tensor/test_spatial_ndim.py | 2 +- tests/transforms/utility/test_splitdim.py | 12 ++++++------ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index a403019751..6561e23480 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -99,9 +99,10 @@ def spatial_resample( src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4) img = convert_to_tensor(data=img, track_meta=get_track_meta()) # ensure spatial rank is <= 3 - spatial_rank = min(get_spatial_ndim(img), 3) + max_rank = max(int(img.ndim) - 1, 1) + spatial_rank = min(get_spatial_ndim(img), max_rank, 3) if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None: - spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size + spatial_rank = min(len(ensure_tuple(spatial_size)), max_rank, 3) # infer spatial rank based on spatial_size src_affine = to_affine_nd(spatial_rank, src_affine).to(torch.float64) dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine dst_affine = convert_to_dst_type(dst_affine, src_affine)[0] diff --git a/tests/data/meta_tensor/test_spatial_ndim.py b/tests/data/meta_tensor/test_spatial_ndim.py index 876d5b2f1e..396bc5a111 100644 --- a/tests/data/meta_tensor/test_spatial_ndim.py +++ b/tests/data/meta_tensor/test_spatial_ndim.py @@ -37,7 +37,7 @@ PRESERVATION_CASES = [ ("reshape", lambda t: t.reshape(1, 100), 2), ("unsqueeze", lambda t: t.unsqueeze(0), 2), - ("squeeze", lambda t: MetaTensor(torch.randn(1, 1, 10, 10), affine=torch.eye(3)).squeeze(1), 2), + ("squeeze", lambda t: t.unsqueeze(1).squeeze(1), 2), ("clone", lambda t: t.clone(), 2), ("deepcopy", lambda t: deepcopy(t), 2), ] diff --git a/tests/transforms/utility/test_splitdim.py b/tests/transforms/utility/test_splitdim.py index 5cf7a08112..b15afb24f2 100644 --- a/tests/transforms/utility/test_splitdim.py +++ b/tests/transforms/utility/test_splitdim.py @@ -56,8 +56,8 @@ def test_spatial_ndim_decremented(self): self.assertEqual(arr.spatial_ndim, 3) out = SplitDim(dim=1, keepdim=False)(arr) for item in out: - if isinstance(item, MetaTensor): - self.assertEqual(item.spatial_ndim, 2) + self.assertIsInstance(item, MetaTensor) + self.assertEqual(item.spatial_ndim, 2) def test_spatial_ndim_negative_dim(self): """spatial_ndim decremented for keepdim=False with negative dim.""" @@ -67,8 +67,8 @@ def test_spatial_ndim_negative_dim(self): self.assertEqual(arr.spatial_ndim, 3) out = SplitDim(dim=-1, keepdim=False)(arr) for item in out: - if isinstance(item, MetaTensor): - self.assertEqual(item.spatial_ndim, 2) + self.assertIsInstance(item, MetaTensor) + self.assertEqual(item.spatial_ndim, 2) def test_spatial_ndim_channel_dim_no_decrement(self): """spatial_ndim not decremented for keepdim=False on channel dim (dim=0).""" @@ -78,8 +78,8 @@ def test_spatial_ndim_channel_dim_no_decrement(self): self.assertEqual(arr.spatial_ndim, 2) out = SplitDim(dim=0, keepdim=False)(arr) for item in out: - if isinstance(item, MetaTensor): - self.assertEqual(item.spatial_ndim, 2) + self.assertIsInstance(item, MetaTensor) + self.assertEqual(item.spatial_ndim, 2) if __name__ == "__main__": From ea915cb884868b950ed7b01f1501a7c70c332de5 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Wed, 4 Mar 2026 06:50:59 +0000 Subject: [PATCH 3/8] ci: retrigger CI checks Signed-off-by: Soumya Snigdha Kundu From e50ae410b85d2cab326d283986fdf47cc7dfa77d Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Wed, 4 Mar 2026 07:25:45 +0000 Subject: [PATCH 4/8] Fix 2D inverse transform shape mismatch (4x4 vs 3x3) Use affine matrix rank instead of spatial_ndim in Rotate, Affine, and RandAffine inverse methods to avoid RuntimeError when spatial_ndim=2 but the stored affine is 4x4. Add regression tests for all three. Signed-off-by: Soumya Snigdha Kundu --- monai/transforms/spatial/array.py | 6 ++-- tests/data/meta_tensor/test_spatial_ndim.py | 34 ++++++++++++++++++++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 76b7f7d737..fdbf37bfa3 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1036,7 +1036,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0] if isinstance(out, MetaTensor): affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False) - mat = to_affine_nd(out.spatial_ndim, transform_t) + mat = to_affine_nd(len(affine) - 1, transform_t) out.affine @= convert_to_dst_type(mat, affine)[0] return out @@ -2353,7 +2353,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: out.meta = data.meta # type: ignore affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] xform, *_ = convert_to_dst_type( - Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine + Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine ) out.affine @= xform return out @@ -2622,7 +2622,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: out.meta = data.meta # type: ignore affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] xform, *_ = convert_to_dst_type( - Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine + Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine ) out.affine @= xform return out diff --git a/tests/data/meta_tensor/test_spatial_ndim.py b/tests/data/meta_tensor/test_spatial_ndim.py index 396bc5a111..6fad34e1bd 100644 --- a/tests/data/meta_tensor/test_spatial_ndim.py +++ b/tests/data/meta_tensor/test_spatial_ndim.py @@ -15,12 +15,13 @@ from copy import deepcopy from unittest import skipUnless +import numpy as np import torch from parameterized import parameterized from monai.data import MetaTensor from monai.data.utils import collate_meta_tensor_fn, decollate_batch -from monai.transforms import Resize, SqueezeDim +from monai.transforms import Affine, RandAffine, Resize, Rotate, SqueezeDim from monai.transforms.utility.array import SplitDim from monai.utils import optional_import @@ -134,6 +135,37 @@ def test_einops_rearrange_then_resize(self): out = Resize(spatial_size=(32, 32, 3), mode="trilinear", align_corners=True)(x_) self.assertEqual(out.shape[-3:], (32, 32, 3)) + def test_affine_inverse_2d_metatensor(self): + """Affine.inverse on 2D data: 4x4 affine with spatial_ndim=2.""" + img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4)) + self.assertEqual(img.spatial_ndim, 2) + xform = Affine(rotate_params=(np.pi / 6,), padding_mode="zeros", image_only=True) + result = xform(img) + inv = xform.inverse(result) + self.assertEqual(inv.shape, img.shape) + self.assertEqual(len(inv.applied_operations), 0) + + def test_rotate_inverse_2d_metatensor(self): + """Rotate.inverse on 2D data: 4x4 affine with spatial_ndim=2.""" + img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4)) + self.assertEqual(img.spatial_ndim, 2) + xform = Rotate(angle=(np.pi / 4,), padding_mode="zeros") + result = xform(img) + inv = xform.inverse(result) + self.assertEqual(inv.shape, img.shape) + self.assertEqual(len(inv.applied_operations), 0) + + def test_rand_affine_inverse_2d_metatensor(self): + """RandAffine.inverse on 2D data: 4x4 affine with spatial_ndim=2.""" + img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4)) + self.assertEqual(img.spatial_ndim, 2) + xform = RandAffine(prob=1.0, rotate_range=(np.pi / 6,), padding_mode="zeros") + xform.set_random_state(seed=42) + result = xform(img) + inv = xform.inverse(result) + self.assertEqual(inv.shape, img.shape) + self.assertEqual(len(inv.applied_operations), 0) + if __name__ == "__main__": unittest.main() From 787eef4f2f6863427448baf6a0959a687e053d75 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Wed, 4 Mar 2026 10:32:26 +0000 Subject: [PATCH 5/8] Fix spatial_ndim drift for sliced MetaTensor 2D paths Signed-off-by: Soumya Snigdha Kundu --- monai/data/meta_tensor.py | 21 +++++++++++++--- monai/data/utils.py | 2 +- tests/data/meta_tensor/test_spatial_ndim.py | 28 +++++++++++++++++++-- 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index e6de8d1a1e..afeea1c171 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -31,6 +31,11 @@ __all__ = ["MetaTensor", "get_spatial_ndim"] +def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int: + """Clamp spatial dims to a valid range for the current tensor shape.""" + return max(1, min(int(spatial_ndim), max(int(tensor_ndim) - 1, 1))) + + def get_spatial_ndim(img: NdarrayOrTensor) -> int: """Return the number of spatial dimensions assuming channel-first layout. @@ -38,7 +43,13 @@ def get_spatial_ndim(img: NdarrayOrTensor) -> int: ``img.ndim - 1``. """ if isinstance(img, MetaTensor): - return img.spatial_ndim + inferred = _normalize_spatial_ndim(img.spatial_ndim, img.ndim) + shape_spatial = max(img.ndim - 1, 1) + # For non-batched tensors, preserve explicit higher-rank shape information + # (e.g., invalid 4D spatial inputs should still be reported as rank 4). + if not img.is_batch and shape_spatial > inferred: + return shape_spatial + return inferred return img.ndim - 1 @@ -175,9 +186,9 @@ def __init__( self.affine = self.get_default_affine() # derive spatial_ndim from affine, clamped by tensor shape if spatial_ndim is not None: - self.spatial_ndim = spatial_ndim + self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim) elif self.affine.ndim == 2: - self.spatial_ndim = min(self.affine.shape[-1] - 1, max(self.ndim - 1, 1)) + self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim) # applied_operations if applied_operations is not None: @@ -243,6 +254,8 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.") if is_batch: ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs) + if func == torch.Tensor.__getitem__: + ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim) out.append(ret) # if the input was a tuple, then return it as a tuple return tuple(out) if isinstance(rets, tuple) else out @@ -492,7 +505,7 @@ def affine(self, d: NdarrayTensor) -> None: a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) self.meta[MetaKeys.AFFINE] = a if a.ndim == 2: # non-batched: sync spatial_ndim - self.spatial_ndim = a.shape[-1] - 1 + self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim) @property def spatial_ndim(self) -> int: diff --git a/monai/data/utils.py b/monai/data/utils.py index 35d19e59ee..64c98a8752 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -432,7 +432,7 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None): collated.meta = default_collate(meta_dicts) collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] collated.is_batch = True - collated.spatial_ndim = getattr(batch[0], "spatial_ndim", 3) # assumes uniform spatial_ndim + collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", 3), max(collated.ndim - 1, 1)) return collated diff --git a/tests/data/meta_tensor/test_spatial_ndim.py b/tests/data/meta_tensor/test_spatial_ndim.py index 6fad34e1bd..e5a9d17c61 100644 --- a/tests/data/meta_tensor/test_spatial_ndim.py +++ b/tests/data/meta_tensor/test_spatial_ndim.py @@ -19,9 +19,9 @@ import torch from parameterized import parameterized -from monai.data import MetaTensor +from monai.data import MetaTensor, get_spatial_ndim from monai.data.utils import collate_meta_tensor_fn, decollate_batch -from monai.transforms import Affine, RandAffine, Resize, Rotate, SqueezeDim +from monai.transforms import Affine, LabelToContour, RandAffine, RandZoom, Resize, Rotate, SqueezeDim from monai.transforms.utility.array import SplitDim from monai.utils import optional_import @@ -121,6 +121,30 @@ def test_lazy_apply_pending_2d(self): self.assertIsInstance(result, MetaTensor) self.assertEqual(len(applied), 1) + def test_batch_slice_clamps_spatial_ndim(self): + t = MetaTensor(torch.randn(10, 6, 5, 7), affine=torch.eye(4)) + self.assertEqual(t.spatial_ndim, 3) + sliced = t[0] + self.assertEqual(sliced.shape, (6, 5, 7)) + self.assertEqual(sliced.spatial_ndim, 2) + self.assertEqual(get_spatial_ndim(sliced), 2) + + def test_label_to_contour_batch_slice_2d(self): + t = MetaTensor(torch.randint(0, 2, (10, 6, 5, 7)).float(), affine=torch.eye(4)) + sliced = t[0] + out = LabelToContour()(sliced) + self.assertEqual(out.shape, sliced.shape) + + def test_rand_zoom_batch_slice_2d(self): + t = MetaTensor(torch.randn(10, 1, 64, 64), affine=torch.eye(4)) + sliced = t[0] + zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=1.2) + zoom.set_random_state(seed=0) + zoom.randomize(sliced) + self.assertEqual(len(zoom._zoom), 2) + out = zoom(sliced) + self.assertEqual(out.ndim, sliced.ndim) + @skipUnless(has_einops, "Requires einops") def test_einops_rearrange_then_resize(self): """Reproduce the exact #6397 bug: einops.rearrange -> Resize.""" From 9f359b08cbbc459e0e36063b6e88b2ed80ba43f2 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Wed, 4 Mar 2026 11:39:31 +0000 Subject: [PATCH 6/8] Fix MetaTensor spatial_ndim propagation regressions - Clamp spatial_ndim only for true batch-only indexing - Handle explicit no-channel metadata when normalizing rank - Remove SplitDim double-decrement after affine sync - Align batch-slice tests with batched MetaTensor metadata - Extract DEFAULT_SPATIAL_NDIM constant to eliminate magic numbers - Add documentation explaining spatial_ndim caching and affine sync Signed-off-by: Soumya Snigdha Kundu --- monai/data/meta_obj.py | 3 + monai/data/meta_tensor.py | 72 ++++++++++++++------- monai/data/utils.py | 4 +- monai/transforms/post/array.py | 4 ++ monai/transforms/spatial/array.py | 12 +++- monai/transforms/utility/array.py | 2 - tests/data/meta_tensor/test_spatial_ndim.py | 6 ++ 7 files changed, 75 insertions(+), 28 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index d903e68d0b..df1bc71334 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -24,6 +24,9 @@ _TRACK_META = True +# Default number of spatial dimensions for medical imaging (3D volumetric data) +_DEFAULT_SPATIAL_NDIM = 3 + __all__ = ["get_track_meta", "set_track_meta", "MetaObj"] diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index afeea1c171..d44dfc362b 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,7 +13,7 @@ import functools import warnings -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any @@ -22,8 +22,8 @@ import monai from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor -from monai.data.meta_obj import MetaObj, get_track_meta -from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata +from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta +from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata from monai.utils import look_up_option from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor @@ -31,9 +31,18 @@ __all__ = ["MetaTensor", "get_spatial_ndim"] -def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int: +def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int: """Clamp spatial dims to a valid range for the current tensor shape.""" - return max(1, min(int(spatial_ndim), max(int(tensor_ndim) - 1, 1))) + limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1) + return max(1, min(int(spatial_ndim), limit)) + + +def _has_explicit_no_channel(meta: Mapping | None) -> bool: + return ( + isinstance(meta, Mapping) + and MetaKeys.ORIGINAL_CHANNEL_DIM in meta + and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM]) + ) def get_spatial_ndim(img: NdarrayOrTensor) -> int: @@ -43,16 +52,22 @@ def get_spatial_ndim(img: NdarrayOrTensor) -> int: ``img.ndim - 1``. """ if isinstance(img, MetaTensor): - inferred = _normalize_spatial_ndim(img.spatial_ndim, img.ndim) - shape_spatial = max(img.ndim - 1, 1) - # For non-batched tensors, preserve explicit higher-rank shape information - # (e.g., invalid 4D spatial inputs should still be reported as rank 4). - if not img.is_batch and shape_spatial > inferred: - return shape_spatial - return inferred + no_channel = _has_explicit_no_channel(img.meta) + return _normalize_spatial_ndim(img.spatial_ndim, img.ndim, no_channel=no_channel) return img.ndim - 1 +def _is_batch_only_index(index: Any) -> bool: + """True when indexing pattern selects only the batch axis (e.g., ``x[0]`` or ``x[0, ...]``).""" + if isinstance(index, (int, np.integer)): + return True + if not isinstance(index, Sequence) or not index: + return False + if not isinstance(index[0], (int, np.integer)): + return False + return all(i in (slice(None, None, None), Ellipsis, None) for i in index[1:]) + + @functools.lru_cache(None) def _get_named_tuple_like_type(func): if ( @@ -184,11 +199,13 @@ def __init__( self.affine = self.meta[MetaKeys.AFFINE] else: self.affine = self.get_default_affine() - # derive spatial_ndim from affine, clamped by tensor shape + # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape. + # This cached value is kept in sync via the affine setter for hot-path performance. + no_channel = _has_explicit_no_channel(self.meta) if spatial_ndim is not None: - self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim) + self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel) elif self.affine.ndim == 2: - self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim) + self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel) # applied_operations if applied_operations is not None: @@ -254,8 +271,6 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.") if is_batch: ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs) - if func == torch.Tensor.__getitem__: - ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim) out.append(ret) # if the input was a tuple, then return it as a tuple return tuple(out) if isinstance(rets, tuple) else out @@ -271,6 +286,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs): if func == torch.Tensor.__getitem__: if idx > 0 or len(args) < 2 or len(args[0]) < 1: return ret + full_idx = args[1] batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1] # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the # first element will be `slice(None, None, None)` and `Ellipsis`, @@ -292,6 +308,8 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs): ret_meta.is_batch = False if hasattr(ret_meta, "__dict__"): ret.__dict__ = ret_meta.__dict__.copy() + if _is_batch_only_index(full_idx): + ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim, no_channel=False) # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. # But we only want to split the batch if the `unbind` is along the 0th dimension. elif func == torch.Tensor.unbind: @@ -501,16 +519,26 @@ def affine(self) -> torch.Tensor: @affine.setter def affine(self, d: NdarrayTensor) -> None: - """Set the affine.""" + """Set the affine. + + When setting a non-batched affine matrix, automatically synchronizes the cached + spatial_ndim attribute to maintain consistency between the affine matrix (source of truth) + and the cached spatial dimension count. + """ a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) self.meta[MetaKeys.AFFINE] = a - if a.ndim == 2: # non-batched: sync spatial_ndim - self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim) + if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth) + no_channel = _has_explicit_no_channel(self.meta) + self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel) @property def spatial_ndim(self) -> int: - """Get the number of spatial dimensions.""" - return getattr(self, "_spatial_ndim", 3) + """Get the number of spatial dimensions. + + This value is cached for hot-path performance and is kept in sync with the affine matrix + via the affine setter. The affine matrix is the source of truth for spatial dimensions. + """ + return getattr(self, "_spatial_ndim", _DEFAULT_SPATIAL_NDIM) @spatial_ndim.setter def spatial_ndim(self, val: int) -> None: diff --git a/monai/data/utils.py b/monai/data/utils.py index 64c98a8752..f14c1b85c1 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -31,7 +31,7 @@ from torch.utils.data._utils.collate import default_collate from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike -from monai.data.meta_obj import MetaObj +from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj from monai.utils import ( MAX_SEED, BlendMode, @@ -432,7 +432,7 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None): collated.meta = default_collate(meta_dicts) collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] collated.is_batch = True - collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", 3), max(collated.ndim - 1, 1)) + collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", _DEFAULT_SPATIAL_NDIM), max(collated.ndim - 1, 1)) return collated diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index ac683f5b3e..3b5d38cf52 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -625,6 +625,10 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = convert_to_tensor(img, track_meta=get_track_meta()) img_: torch.Tensor = convert_to_tensor(img, track_meta=False) spatial_dims = get_spatial_ndim(img) + # Validate actual tensor shape against tracked spatial_ndim + actual_spatial = img_.ndim - 1 # channel-first layout + if actual_spatial != spatial_dims: + spatial_dims = actual_spatial img_ = img_.unsqueeze(0) # adds a batch dim if spatial_dims == 2: kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index fdbf37bfa3..8c228dbb0a 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1036,6 +1036,9 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0] if isinstance(out, MetaTensor): affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False) + # Use affine matrix shape directly (not spatial_ndim) because the affine may be + # larger than the spatial dimensions (e.g., 4x4 for 2D data), and we need to match + # the actual affine matrix rank being composed mat = to_affine_nd(len(affine) - 1, transform_t) out.affine @= convert_to_dst_type(mat, affine)[0] return out @@ -2352,6 +2355,8 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: out = MetaTensor(out) out.meta = data.meta # type: ignore affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] + # Use affine matrix shape directly (not spatial_ndim) to ensure matrix composition compatibility + # when affine is larger than spatial dimensions (e.g., 4x4 for 2D data) xform, *_ = convert_to_dst_type( Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine ) @@ -2621,6 +2626,8 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: out = MetaTensor(out) out.meta = data.meta # type: ignore affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] + # Use affine matrix shape directly (not spatial_ndim) to ensure matrix composition compatibility + # when affine is larger than spatial dimensions (e.g., 4x4 for 2D data) xform, *_ = convert_to_dst_type( Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine ) @@ -3035,10 +3042,11 @@ def __call__( raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`") all_ranges = [] - num_cells = ensure_tuple_rep(self.num_cells, get_spatial_ndim(img)) + _sp = get_spatial_ndim(img) + num_cells = ensure_tuple_rep(self.num_cells, _sp) if isinstance(img, MetaTensor) and img.pending_operations: warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.") - for dim_idx, dim_size in enumerate(img.shape[1:]): + for dim_idx, dim_size in enumerate(img.shape[1 : 1 + _sp]): dim_distort_steps = distort_steps[dim_idx] ranges = torch.zeros(dim_size, dtype=torch.float32) cell_size = dim_size // num_cells[dim_idx] diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bd86120350..9919b9a6eb 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -334,8 +334,6 @@ def __call__(self, img: torch.Tensor) -> list[torch.Tensor]: shift = torch.eye(ndim, device=out.affine.device, dtype=out.affine.dtype) shift[dim - 1, -1] = idx out.affine = out.affine @ shift - if not self.keepdim: - out.spatial_ndim = max(1, out.spatial_ndim - 1) return outputs diff --git a/tests/data/meta_tensor/test_spatial_ndim.py b/tests/data/meta_tensor/test_spatial_ndim.py index e5a9d17c61..d7ef9bd9f7 100644 --- a/tests/data/meta_tensor/test_spatial_ndim.py +++ b/tests/data/meta_tensor/test_spatial_ndim.py @@ -123,6 +123,8 @@ def test_lazy_apply_pending_2d(self): def test_batch_slice_clamps_spatial_ndim(self): t = MetaTensor(torch.randn(10, 6, 5, 7), affine=torch.eye(4)) + t.is_batch = True + t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1) self.assertEqual(t.spatial_ndim, 3) sliced = t[0] self.assertEqual(sliced.shape, (6, 5, 7)) @@ -131,12 +133,16 @@ def test_batch_slice_clamps_spatial_ndim(self): def test_label_to_contour_batch_slice_2d(self): t = MetaTensor(torch.randint(0, 2, (10, 6, 5, 7)).float(), affine=torch.eye(4)) + t.is_batch = True + t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1) sliced = t[0] out = LabelToContour()(sliced) self.assertEqual(out.shape, sliced.shape) def test_rand_zoom_batch_slice_2d(self): t = MetaTensor(torch.randn(10, 1, 64, 64), affine=torch.eye(4)) + t.is_batch = True + t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1) sliced = t[0] zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=1.2) zoom.set_random_state(seed=0) From dc16dda65f7d973fc2c9e592cc647ee75baf3afe Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Wed, 4 Mar 2026 14:40:00 +0000 Subject: [PATCH 7/8] Fix 2D inverse transform failures by removing no_channel from spatial_ndim normalization After EnsureChannelFirst adds a channel dim, ORIGINAL_CHANNEL_DIM="no_channel" refers to the original file, not the current tensor. The no_channel flag caused _normalize_spatial_ndim to treat all dims as spatial (returning 3 instead of 2), breaking Resized_2d, Resized_longest_2d, and Zoomd_2d inverse transforms. Signed-off-by: Soumya Snigdha Kundu --- monai/data/meta_tensor.py | 27 ++++++++------------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index d44dfc362b..a76b37a4c1 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -23,7 +23,7 @@ import monai from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta -from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata +from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor @@ -31,20 +31,12 @@ __all__ = ["MetaTensor", "get_spatial_ndim"] -def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int: +def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int: """Clamp spatial dims to a valid range for the current tensor shape.""" - limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1) + limit = max(int(tensor_ndim) - 1, 1) return max(1, min(int(spatial_ndim), limit)) -def _has_explicit_no_channel(meta: Mapping | None) -> bool: - return ( - isinstance(meta, Mapping) - and MetaKeys.ORIGINAL_CHANNEL_DIM in meta - and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM]) - ) - - def get_spatial_ndim(img: NdarrayOrTensor) -> int: """Return the number of spatial dimensions assuming channel-first layout. @@ -52,8 +44,7 @@ def get_spatial_ndim(img: NdarrayOrTensor) -> int: ``img.ndim - 1``. """ if isinstance(img, MetaTensor): - no_channel = _has_explicit_no_channel(img.meta) - return _normalize_spatial_ndim(img.spatial_ndim, img.ndim, no_channel=no_channel) + return _normalize_spatial_ndim(img.spatial_ndim, img.ndim) return img.ndim - 1 @@ -201,11 +192,10 @@ def __init__( self.affine = self.get_default_affine() # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape. # This cached value is kept in sync via the affine setter for hot-path performance. - no_channel = _has_explicit_no_channel(self.meta) if spatial_ndim is not None: - self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel) + self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim) elif self.affine.ndim == 2: - self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel) + self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim) # applied_operations if applied_operations is not None: @@ -309,7 +299,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs): if hasattr(ret_meta, "__dict__"): ret.__dict__ = ret_meta.__dict__.copy() if _is_batch_only_index(full_idx): - ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim, no_channel=False) + ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim) # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. # But we only want to split the batch if the `unbind` is along the 0th dimension. elif func == torch.Tensor.unbind: @@ -528,8 +518,7 @@ def affine(self, d: NdarrayTensor) -> None: a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) self.meta[MetaKeys.AFFINE] = a if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth) - no_channel = _has_explicit_no_channel(self.meta) - self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel) + self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim) @property def spatial_ndim(self) -> int: From ab2be3a8ad7a2dfae080a0a1583ad9e5613f0607 Mon Sep 17 00:00:00 2001 From: Soumya Snigdha Kundu Date: Wed, 4 Mar 2026 15:36:23 +0000 Subject: [PATCH 8/8] Fix get_spatial_ndim for 2D post-EnsureChannelFirst without regressing 3D The previous commit removed no_channel from all call sites, which broke 3D no-channel tensors (spatial_ndim clamped to ndim-1=2 instead of 3). Targeted fix: only remove no_channel from get_spatial_ndim, since it runs after EnsureChannelFirst has already added a channel dim. The constructor and affine setter keep no_channel to correctly handle pre-channel tensors. Signed-off-by: Soumya Snigdha Kundu --- monai/data/meta_tensor.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index a76b37a4c1..cd6de98613 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -23,7 +23,7 @@ import monai from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta -from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata +from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata from monai.utils import look_up_option from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor @@ -31,17 +31,26 @@ __all__ = ["MetaTensor", "get_spatial_ndim"] -def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int: +def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int: """Clamp spatial dims to a valid range for the current tensor shape.""" - limit = max(int(tensor_ndim) - 1, 1) + limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1) return max(1, min(int(spatial_ndim), limit)) +def _has_explicit_no_channel(meta: Mapping | None) -> bool: + return ( + isinstance(meta, Mapping) + and MetaKeys.ORIGINAL_CHANNEL_DIM in meta + and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM]) + ) + + def get_spatial_ndim(img: NdarrayOrTensor) -> int: """Return the number of spatial dimensions assuming channel-first layout. Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to - ``img.ndim - 1``. + ``img.ndim - 1``. Always assumes channel-first (``no_channel=False``) + because callers run after ``EnsureChannelFirst`` has already added one. """ if isinstance(img, MetaTensor): return _normalize_spatial_ndim(img.spatial_ndim, img.ndim) @@ -192,10 +201,11 @@ def __init__( self.affine = self.get_default_affine() # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape. # This cached value is kept in sync via the affine setter for hot-path performance. + no_channel = _has_explicit_no_channel(self.meta) if spatial_ndim is not None: - self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim) + self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel) elif self.affine.ndim == 2: - self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim) + self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel) # applied_operations if applied_operations is not None: @@ -518,7 +528,8 @@ def affine(self, d: NdarrayTensor) -> None: a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) self.meta[MetaKeys.AFFINE] = a if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth) - self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim) + no_channel = _has_explicit_no_channel(self.meta) + self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel) @property def spatial_ndim(self) -> int: