diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index 5e367cc297..fe59e6eb8a 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -71,7 +71,7 @@
     monai_to_itk_ddf,
 )
 from .meta_obj import MetaObj, get_track_meta, set_track_meta
-from .meta_tensor import MetaTensor
+from .meta_tensor import MetaTensor, get_spatial_ndim
 from .samplers import DistributedSampler, DistributedWeightedRandomSampler
 from .synthetic import create_test_image_2d, create_test_image_3d
 from .test_time_augmentation import TestTimeAugmentation
diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py
index 15e6e8be15..df1bc71334 100644
--- a/monai/data/meta_obj.py
+++ b/monai/data/meta_obj.py
@@ -24,6 +24,9 @@
 _TRACK_META = True
 
+# Default number of spatial dimensions for medical imaging (3D volumetric data)
+_DEFAULT_SPATIAL_NDIM = 3
+
 __all__ = ["get_track_meta", "set_track_meta", "MetaObj"]
 
@@ -84,6 +87,7 @@ def __init__(self) -> None:
         self._applied_operations: list = MetaObj.get_default_applied_operations()
         self._pending_operations: list = MetaObj.get_default_applied_operations()  # the same default as applied_ops
         self._is_batch: bool = False
+        self._spatial_ndim: int = _DEFAULT_SPATIAL_NDIM  # default: 3 spatial dimensions
 
     @staticmethod
     def flatten_meta_objs(*args: Iterable):
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 12bd76ba60..cd6de98613 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -13,7 +13,7 @@
 
 import functools
 import warnings
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from copy import deepcopy
 from typing import Any
 
@@ -21,14 +21,51 @@
 import torch
 
 import monai
-from monai.config.type_definitions import NdarrayTensor
-from monai.data.meta_obj import MetaObj, get_track_meta
-from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
+from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
+from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta
+from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata
 from monai.utils import look_up_option
 from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
 from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor
 
-__all__ = ["MetaTensor"]
+__all__ = ["MetaTensor", "get_spatial_ndim"]
+
+
+def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int:
+    """Clamp spatial dims to a valid range for the current tensor shape."""
+    limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1)
+    return max(1, min(int(spatial_ndim), limit))
+
+
+def _has_explicit_no_channel(meta: Mapping | None) -> bool:
+    return (
+        isinstance(meta, Mapping)
+        and MetaKeys.ORIGINAL_CHANNEL_DIM in meta
+        and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM])
+    )
+
+
+def get_spatial_ndim(img: NdarrayOrTensor) -> int:
+    """Return the number of spatial dimensions assuming channel-first layout.
+
+    Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
+    ``img.ndim - 1``. Always assumes channel-first (``no_channel=False``)
+    because callers run after ``EnsureChannelFirst`` has already added one.
+    """
+    if isinstance(img, MetaTensor):
+        return _normalize_spatial_ndim(img.spatial_ndim, img.ndim)
+    return max(img.ndim - 1, 1)  # keep the >=1 invariant for plain arrays too
+
+
+def _is_batch_only_index(index: Any) -> bool:
+    """True when indexing pattern selects only the batch axis (e.g., ``x[0]`` or ``x[0, ...]``)."""
+    if isinstance(index, (int, np.integer)):
+        return True
+    if not isinstance(index, Sequence) or not index:
+        return False
+    if not isinstance(index[0], (int, np.integer)):
+        return False
+    return all(i in (slice(None, None, None), Ellipsis, None) for i in index[1:])
 
 
 @functools.lru_cache(None)
@@ -111,6 +148,7 @@ def __new__(
         meta: dict | None = None,
         applied_operations: list | None = None,
         *args,
+        spatial_ndim: int | None = None,
         **kwargs,
     ) -> MetaTensor:
         _kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
@@ -123,6 +161,7 @@ def __init__(
         meta: dict | None = None,
         applied_operations: list | None = None,
         *_args,
+        spatial_ndim: int | None = None,
         **_kwargs,
     ) -> None:
         """
@@ -134,6 +173,8 @@
                 the list is typically maintained by `monai.transforms.TraceableTransform`.
                 See also: :py:class:`monai.transforms.TraceableTransform`
             _args: additional args (currently not in use in this constructor).
+            spatial_ndim: optional number of spatial dimensions. If ``None``, derived
+                from the affine matrix clamped by the tensor shape.
             _kwargs: additional kwargs (currently not in use in this constructor).
 
         Note:
@@ -158,6 +199,14 @@
                 self.affine = self.meta[MetaKeys.AFFINE]
             else:
                 self.affine = self.get_default_affine()
+        # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape.
+        # This cached value is kept in sync via the affine setter for hot-path performance.
+        no_channel = _has_explicit_no_channel(self.meta)
+        if spatial_ndim is not None:
+            self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel)
+        elif self.affine.ndim == 2:
+            self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel)
+
         # applied_operations
         if applied_operations is not None:
             self.applied_operations = applied_operations
@@ -237,6 +286,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
         if func == torch.Tensor.__getitem__:
             if idx > 0 or len(args) < 2 or len(args[0]) < 1:
                 return ret
+            full_idx = args[1]
             batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
             # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
             # first element will be `slice(None, None, None)` and `Ellipsis`,
@@ -258,6 +308,8 @@
                 ret_meta.is_batch = False
             if hasattr(ret_meta, "__dict__"):
                 ret.__dict__ = ret_meta.__dict__.copy()
+            if _is_batch_only_index(full_idx):
+                ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim)
         # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
         # But we only want to split the batch if the `unbind` is along the 0th dimension.
         elif func == torch.Tensor.unbind:
@@ -467,15 +519,40 @@ def affine(self) -> torch.Tensor:
 
     @affine.setter
     def affine(self, d: NdarrayTensor) -> None:
-        """Set the affine."""
-        self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
+        """Set the affine.
+
+        When setting a non-batched affine matrix, automatically synchronizes the cached
+        spatial_ndim attribute to maintain consistency between the affine matrix (source of truth)
+        and the cached spatial dimension count.
+        """
+        a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
+        self.meta[MetaKeys.AFFINE] = a
+        if a.ndim == 2:  # non-batched: sync spatial_ndim from affine (source of truth)
+            no_channel = _has_explicit_no_channel(self.meta)
+            self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel)
+
+    @property
+    def spatial_ndim(self) -> int:
+        """Get the number of spatial dimensions.
+
+        This value is cached for hot-path performance and is kept in sync with the affine matrix
+        via the affine setter. The affine matrix is the source of truth for spatial dimensions.
+        """
+        return getattr(self, "_spatial_ndim", _DEFAULT_SPATIAL_NDIM)
+
+    @spatial_ndim.setter
+    def spatial_ndim(self, val: int) -> None:
+        """Set the number of spatial dimensions."""
+        if val < 1:
+            raise ValueError(f"spatial_ndim must be >= 1, got {val}")
+        self._spatial_ndim = int(val)  # coerce so numpy/tensor ints don't leak into the cache
 
     @property
     def pixdim(self):
         """Get the spacing"""
         if self.is_batch:
-            return [affine_to_spacing(a) for a in self.affine]
-        return affine_to_spacing(self.affine)
+            return [affine_to_spacing(a, r=self.spatial_ndim) for a in self.affine]
+        return affine_to_spacing(self.affine, r=self.spatial_ndim)
 
     def peek_pending_shape(self):
         """
@@ -490,7 +567,7 @@
 
     def peek_pending_affine(self):
         res = self.affine
-        r = len(res) - 1
+        r = res.shape[-1] - 1 if res.ndim >= 2 else self.spatial_ndim
         if r not in (2, 3):
             warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.")
         for p in self.pending_operations:
@@ -503,8 +580,10 @@
         return res
 
     def peek_pending_rank(self):
-        a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
-        return 1 if a is None else int(max(1, len(a) - 1))
+        if self.pending_operations:
+            a = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
+            return 1 if a is None else int(max(1, len(a) - 1))
+        return self.spatial_ndim
 
     def new_empty(self, size, dtype=None, device=None, requires_grad=False):  # type: ignore[override]
         """
diff --git a/monai/data/utils.py b/monai/data/utils.py
index 4e5a3bd7f6..f14c1b85c1 100644
--- a/monai/data/utils.py
+++ b/monai/data/utils.py
@@ -31,7 +31,7 @@
 from torch.utils.data._utils.collate import default_collate
 
 from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
-from monai.data.meta_obj import MetaObj
+from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj
 from monai.utils import (
     MAX_SEED,
     BlendMode,
@@ -432,6 +432,7 @@
         collated.meta = default_collate(meta_dicts)
         collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
         collated.is_batch = True
+        collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", _DEFAULT_SPATIAL_NDIM), max(collated.ndim - 1, 1))
     return collated
diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py
index acf42849d3..378f1cf688 100644
--- a/monai/transforms/croppad/functional.py
+++ b/monai/transforms/croppad/functional.py
@@ -22,7 +22,7 @@
 
 from monai.config.type_definitions import NdarrayTensor
 from monai.data.meta_obj import get_track_meta
-from monai.data.meta_tensor import MetaTensor
+from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
 from monai.data.utils import to_affine_nd
 from monai.transforms.inverse import TraceableTransform
 from monai.transforms.utils import convert_pad_mode, create_translate
@@ -132,7 +132,7 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
         mode: the padding mode.
         kwargs: other arguments for the `np.pad` or `torch.pad` function.
     """
-    ndim = len(img.shape) - 1
+    ndim = get_spatial_ndim(img)
     matrix_np = np.round(to_affine_nd(ndim, convert_to_numpy(translation_mat, wrap_sequence=True).copy()))
     matrix_np = to_affine_nd(len(spatial_size), matrix_np)
     cc = np.asarray(np.meshgrid(*[[0.5, x - 0.5] for x in spatial_size], indexing="ij"))
diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index 0421d34492..6adce0b49e 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -26,6 +26,7 @@
 from monai.config import DtypeLike
 from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
 from monai.data.meta_obj import get_track_meta
+from monai.data.meta_tensor import get_spatial_ndim
 from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap
 from monai.data.utils import get_random_patch, get_valid_patch_size
 from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
@@ -1580,7 +1581,7 @@ def __init__(self, radius: Sequence[int] | int = 1) -> None:
     def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
         img = convert_to_tensor(img, track_meta=get_track_meta())
         img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
-        spatial_dims = img_t.ndim - 1
+        spatial_dims = get_spatial_ndim(img)
         r = ensure_tuple_rep(self.radius, spatial_dims)
         median_filter_instance = MedianFilter(r, spatial_dims=spatial_dims)
         out_t: torch.Tensor = median_filter_instance(img_t)
@@ -1616,7 +1617,7 @@
             sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma]
         else:
             sigma = torch.as_tensor(self.sigma, device=img_t.device)
-        gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx)
+        gaussian_filter = GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx)
         out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0)
         out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)
@@ -1673,7 +1674,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
         if not self._do_transform:
             return img
 
-        sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=img.ndim - 1)
+        sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=get_spatial_ndim(img))
         return GaussianSmooth(sigma=sigma, approx=self.approx)(img)
@@ -1723,7 +1724,7 @@
         img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32)
 
         gf1, gf2 = (
-            GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device)
+            GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx).to(img_t.device)
             for sigma in (self.sigma1, self.sigma2)
         )
         blurred_f = gf1(img_t.unsqueeze(0))
@@ -1811,8 +1812,9 @@
         if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None:
             raise RuntimeError("please call the `randomize()` function first.")
 
-        sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=img.ndim - 1)
-        sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=img.ndim - 1)
+        _sp = get_spatial_ndim(img)
+        sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=_sp)
+        sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=_sp)
         return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img)
diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py
index ecf918f47a..fdde1684bb 100644
--- a/monai/transforms/inverse.py
+++ b/monai/transforms/inverse.py
@@ -213,7 +213,7 @@ def track_transform_meta(
             orig_affine = data_t.peek_pending_affine()
             orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
             try:
-                affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
+                affine = orig_affine @ to_affine_nd(orig_affine.shape[-1] - 1, affine, dtype=torch.float64)
             except RuntimeError as e:
                 if orig_affine.ndim > 2:
                     if data_t.is_batch:
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index a33d76807c..d693ba6810 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -256,9 +256,11 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
     if not pending:
         return data, []
 
+    _rank = data.spatial_ndim if isinstance(data, MetaTensor) else 3
+
     cumulative_xform = affine_from_pending(pending[0])
-    if cumulative_xform.shape[0] == 3:
-        cumulative_xform = to_affine_nd(3, cumulative_xform)
+    if cumulative_xform.shape[0] < _rank + 1:
+        cumulative_xform = to_affine_nd(_rank, cumulative_xform)
 
     cur_kwargs = kwargs_from_pending(pending[0])
     override_kwargs: dict[str, Any] = {}
@@ -283,8 +285,8 @@
             data = resample(data.to(device), cumulative_xform, _cur_kwargs)
 
         next_matrix = affine_from_pending(p)
-        if next_matrix.shape[0] == 3:
-            next_matrix = to_affine_nd(3, next_matrix)
+        if next_matrix.shape[0] < _rank + 1:
+            next_matrix = to_affine_nd(_rank, next_matrix)
 
         cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
         cur_kwargs.update(new_kwargs)
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index 47623b748d..3b5d38cf52 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -23,7 +23,7 @@
 
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.meta_obj import get_track_meta
-from monai.data.meta_tensor import MetaTensor
+from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
 from monai.networks import one_hot
 from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
 from monai.transforms.inverse import InvertibleTransform
@@ -624,7 +624,11 @@
         """
         img = convert_to_tensor(img, track_meta=get_track_meta())
         img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
-        spatial_dims = len(img_.shape) - 1
+        spatial_dims = get_spatial_ndim(img)
+        # guard: if the tracked value disagrees with the channel-first shape, prefer the shape
+        actual_spatial = img_.ndim - 1  # channel-first layout
+        if actual_spatial != spatial_dims:
+            spatial_dims = actual_spatial
         img_ = img_.unsqueeze(0)  # adds a batch dim
         if spatial_dims == 2:
             kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
@@ -1104,7 +1108,7 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
         image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
 
         # Check/set spatial axes
-        n_spatial_dims = image_tensor.ndim - 1  # excluding the channel dimension
+        n_spatial_dims = get_spatial_ndim(image_tensor)
         valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))
 
         # Check gradient axes to be valid
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index b6bf211cc4..8c228dbb0a 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -27,7 +27,7 @@
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta, set_track_meta
-from monai.data.meta_tensor import MetaTensor
+from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
 from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
 from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
 from monai.networks.utils import meshgrid_ij
@@ -848,12 +848,14 @@ def __call__(
         anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing
         anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma
 
-        input_ndim = img.ndim - 1  # spatial ndim
+        input_ndim = get_spatial_ndim(img)
         if self.size_mode == "all":
             output_ndim = len(ensure_tuple(self.spatial_size))
             if output_ndim > input_ndim:
                 input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
                 img = img.reshape(input_shape)
+                if isinstance(img, MetaTensor):
+                    img.spatial_ndim = output_ndim
             elif output_ndim < input_ndim:
                 raise ValueError(
                     "len(spatial_size) must be greater or equal to img spatial dimensions, "
@@ -1034,6 +1036,9 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
         out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0]
         if isinstance(out, MetaTensor):
             affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False)
+            # Use affine matrix shape directly (not spatial_ndim) because the affine may be
+            # larger than the spatial dimensions (e.g., 4x4 for 2D data), and we need to match
+            # the actual affine matrix rank being composed
             mat = to_affine_nd(len(affine) - 1, transform_t)
             out.affine @= convert_to_dst_type(mat, affine)[0]
         return out
@@ -1131,7 +1136,7 @@ def __call__(
                 during initialization for this call. Defaults to None.
         """
         img = convert_to_tensor(img, track_meta=get_track_meta())
-        _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1)  # match the spatial image dim
+        _zoom = ensure_tuple_rep(self.zoom, get_spatial_ndim(img))
         _mode = self.mode if mode is None else mode
         _padding_mode = padding_mode or self.padding_mode
         _align_corners = self.align_corners if align_corners is None else align_corners
@@ -1519,7 +1524,7 @@ def randomize(self, data: NdarrayOrTensor) -> None:
         super().randomize(None)
         if not self._do_transform:
             return None
-        self._axis = self.R.randint(data.ndim - 1)
+        self._axis = self.R.randint(get_spatial_ndim(data))
 
     def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:
         """
@@ -1629,13 +1634,14 @@ def randomize(self, img: NdarrayOrTensor) -> None:
         super().randomize(None)
         if not self._do_transform:
             return None
+        _sp = get_spatial_ndim(img)
         self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]
         if len(self._zoom) == 1:
             # to keep the spatial shape ratio, use same random zoom factor for all dims
-            self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
-        elif len(self._zoom) == 2 and img.ndim > 3:
+            self._zoom = ensure_tuple_rep(self._zoom[0], _sp)
+        elif len(self._zoom) == 2 and _sp > 2:
             # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
-            self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
+            self._zoom = ensure_tuple_rep(self._zoom[0], _sp - 1) + ensure_tuple(self._zoom[-1])
 
     def __call__(
         self,
@@ -2349,6 +2355,8 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
         out = MetaTensor(out)
         out.meta = data.meta  # type: ignore
         affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
+        # Use affine matrix shape directly (not spatial_ndim) to ensure matrix composition compatibility
+        # when affine is larger than spatial dimensions (e.g., 4x4 for 2D data)
         xform, *_ = convert_to_dst_type(
             Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
         )
@@ -2618,6 +2626,8 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
         out = MetaTensor(out)
         out.meta = data.meta  # type: ignore
         affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
+        # Use affine matrix shape directly (not spatial_ndim) to ensure matrix composition compatibility
+        # when affine is larger than spatial dimensions (e.g., 4x4 for 2D data)
         xform, *_ = convert_to_dst_type(
             Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
         )
@@ -3032,10 +3042,11 @@ def __call__(
             raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`")
 
         all_ranges = []
-        num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1)
+        _sp = get_spatial_ndim(img)
+        num_cells = ensure_tuple_rep(self.num_cells, _sp)
         if isinstance(img, MetaTensor) and img.pending_operations:
             warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.")
-        for dim_idx, dim_size in enumerate(img.shape[1:]):
+        for dim_idx, dim_size in enumerate(img.shape[1 : 1 + _sp]):
             dim_distort_steps = distort_steps[dim_idx]
             ranges = torch.zeros(dim_size, dtype=torch.float32)
             cell_size = dim_size // num_cells[dim_idx]
diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py
index 3001dd1e64..6561e23480 100644
--- a/monai/transforms/spatial/functional.py
+++ b/monai/transforms/spatial/functional.py
@@ -26,7 +26,7 @@
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.box_utils import get_boxmode
 from monai.data.meta_obj import get_track_meta
-from monai.data.meta_tensor import MetaTensor
+from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
 from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
 from monai.networks.layers import AffineTransform
 from monai.transforms.croppad.array import ResizeWithPadOrCrop
@@ -99,9 +99,10 @@ def spatial_resample(
     src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4)
     img = convert_to_tensor(data=img, track_meta=get_track_meta())
     # ensure spatial rank is <= 3
-    spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3)
+    max_rank = max(int(img.ndim) - 1, 1)
+    spatial_rank = min(get_spatial_ndim(img), max_rank, 3)
     if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None:
-        spatial_rank = min(len(ensure_tuple(spatial_size)), 3)  # infer spatial rank based on spatial_size
+        spatial_rank = min(len(ensure_tuple(spatial_size)), max_rank, 3)  # infer spatial rank based on spatial_size
     src_affine = to_affine_nd(spatial_rank, src_affine).to(torch.float64)
     dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine
     dst_affine = convert_to_dst_type(dst_affine, src_affine)[0]
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 3dc7897feb..9919b9a6eb 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -30,7 +30,7 @@
 from monai.config import DtypeLike
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.meta_obj import get_track_meta
-from monai.data.meta_tensor import MetaTensor
+from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
 from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps
 from monai.networks.layers.simplelayers import (
     ApplyFilter,
@@ -314,23 +314,26 @@ def __call__(self, img: torch.Tensor) -> list[torch.Tensor]:
         """
         Apply the transform to `img`.
         """
-        n_out = img.shape[self.dim]
+        dim = self.dim if self.dim >= 0 else self.dim + img.ndim
+        n_out = img.shape[dim]
         if isinstance(img, torch.Tensor):
-            outputs = list(torch.split(img, 1, self.dim))
+            outputs = list(torch.split(img, 1, dim))
         else:
-            outputs = np.split(img, n_out, self.dim)
+            outputs = np.split(img, n_out, dim)
         for idx, item in enumerate(outputs):
             if not self.keepdim:
-                outputs[idx] = item.squeeze(self.dim)
+                outputs[idx] = item.squeeze(dim)
             if self.update_meta and isinstance(img, MetaTensor):
-                if not isinstance(item, MetaTensor):
-                    item = MetaTensor(item, meta=img.meta)
-                if self.dim == 0:  # don't update affine if channel dim
+                out = outputs[idx]
+                if not isinstance(out, MetaTensor):
+                    out = MetaTensor(out, meta=img.meta)
+                outputs[idx] = out
+                if dim == 0:  # don't update affine if channel dim
                     continue
-                ndim = len(item.affine)
-                shift = torch.eye(ndim, device=item.affine.device, dtype=item.affine.dtype)
-                shift[self.dim - 1, -1] = idx
-                item.affine = item.affine @ shift
+                ndim = len(out.affine)
+                shift = torch.eye(ndim, device=out.affine.device, dtype=out.affine.dtype)
+                shift[dim - 1, -1] = idx
+                out.affine = out.affine @ shift
         return outputs
@@ -1506,8 +1509,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
         Args:
             img: data to be transformed, assuming `img` is channel first.
         """
-        if max(self.spatial_dims) > img.ndim - 2 or min(self.spatial_dims) < 0:
-            raise ValueError(f"`spatial_dims` values must be within [0, {img.ndim - 2}]")
+        _sp = get_spatial_ndim(img)
+        if max(self.spatial_dims) > _sp - 1 or min(self.spatial_dims) < 0:
+            raise ValueError(f"`spatial_dims` values must be within [0, {_sp - 1}]")
 
         spatial_size = img.shape[1:]
         coord_channels = np.array(np.meshgrid(*tuple(np.linspace(-0.5, 0.5, s) for s in spatial_size), indexing="ij"))
@@ -1675,7 +1679,7 @@ def __call__(
             applied_operations = img.applied_operations
         img_, prev_type, device = convert_data_type(img, torch.Tensor)
 
-        ndim = img_.ndim - 1  # assumes channel first format
+        ndim = get_spatial_ndim(img)
         if isinstance(self.filter, str):
             self.filter = self._get_filter_from_string(self.filter, self.filter_size, ndim)  # type: ignore
diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py
index 427902f784..e9fb834d62 100644
--- a/tests/data/meta_tensor/test_meta_tensor.py
+++ b/tests/data/meta_tensor/test_meta_tensor.py
@@ -68,6 +68,7 @@ def check_ids(self, a, b, should_match):
 
     def check_meta(self, a: MetaTensor, b: MetaTensor) -> None:
         self.assertEqual(a.is_batch, b.is_batch)
+        self.assertEqual(a.spatial_ndim, b.spatial_ndim)
         meta_a, meta_b = a.meta, b.meta
         # need to split affine from rest of metadata
         aff_a = meta_a.get("affine", None)
diff --git a/tests/data/meta_tensor/test_spatial_ndim.py b/tests/data/meta_tensor/test_spatial_ndim.py
new file mode 100644
index 0000000000..d7ef9bd9f7
--- /dev/null
+++ b/tests/data/meta_tensor/test_spatial_ndim.py
@@ -0,0 +1,201 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from copy import deepcopy
+from unittest import skipUnless
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.data import MetaTensor, get_spatial_ndim
+from monai.data.utils import collate_meta_tensor_fn, decollate_batch
+from monai.transforms import Affine, LabelToContour, RandAffine, RandZoom, Resize, Rotate, SqueezeDim
+from monai.transforms.utility.array import SplitDim
+from monai.utils import optional_import
+
+einops, has_einops = optional_import("einops")
+
+# (shape, affine, expected_spatial_ndim)
+CONSTRUCTION_CASES = [
+    ((1, 10, 10, 10), None, 3),  # default eye(4)
+    ((1, 10, 10), torch.eye(3), 2),  # eye(3)
+    ((1, 10), torch.eye(2), 1),  # eye(2)
+]
+
+# (description, op, expected_spatial_ndim) -- op takes a 2D MetaTensor and returns a new one
+PRESERVATION_CASES = [
+    ("reshape", lambda t: t.reshape(1, 100), 2),
+    ("unsqueeze", lambda t: t.unsqueeze(0), 2),
+    ("squeeze", lambda t: t.unsqueeze(1).squeeze(1), 2),
+    ("clone", lambda t: t.clone(), 2),
+    ("deepcopy", lambda t: deepcopy(t), 2),
+]
+
+
+class TestSpatialNdim(unittest.TestCase):
+    @parameterized.expand(CONSTRUCTION_CASES)
+    def test_construction(self, shape, affine, expected):
+        kwargs = {"affine": affine} if affine is not None else {}
+        t = MetaTensor(torch.randn(*shape), **kwargs)
+        self.assertEqual(t.spatial_ndim, expected)
+
+    @parameterized.expand(PRESERVATION_CASES)
+    def test_preserved_through_op(self, _desc, op, expected):
+        t = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3))
+        t2 = op(t)
+        self.assertEqual(t2.spatial_ndim, expected)
+
+    def test_setter_and_validation(self):
+        t = MetaTensor(torch.randn(1, 10, 10, 10))
+        t.spatial_ndim = 2
+        self.assertEqual(t.spatial_ndim, 2)
+        for bad in (0, -1):
+            with self.assertRaises(ValueError):
+                t.spatial_ndim = bad
+
+    def test_affine_setter_syncs(self):
+        t = MetaTensor(torch.randn(1, 10, 10, 10))
+        t.affine = torch.eye(3)
+        self.assertEqual(t.spatial_ndim, 2)
+
+    def test_copy_from_meta_tensor(self):
+        t1 = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3))
+        self.assertEqual(MetaTensor(t1).spatial_ndim, 2)
+
+    def test_collate_and_decollate(self):
+        t1 = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3))
+        t2 = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3))
+        batch = collate_meta_tensor_fn([t1, t2])
+        self.assertEqual(batch.spatial_ndim, 2)
+        for item in decollate_batch(batch):
+            self.assertIsInstance(item, MetaTensor)
+            self.assertEqual(item.spatial_ndim, 2)
+
+    def test_derived_properties(self):
+        """peek_pending_rank, peek_pending_shape, and pixdim all respect spatial_ndim."""
+        aff = torch.diag(torch.tensor([2.0, 3.0, 1.0], dtype=torch.float64))
+        t = MetaTensor(torch.randn(1, 10, 10), affine=aff)
+        self.assertEqual(t.peek_pending_rank(), 2)
+        self.assertEqual(t.peek_pending_shape(), (10, 10))
+        self.assertEqual(len(t.pixdim), 2)
+
+    def test_squeeze_dim_transform(self):
+        t = MetaTensor(torch.randn(1, 10, 1, 10))
+        result = SqueezeDim(dim=2)(t)
+        self.assertEqual(result.spatial_ndim, result.affine.shape[-1] - 1)
+
+    def test_splitdim_channel_dim_no_decrement(self):
+        t = MetaTensor(torch.randn(3, 8, 7))
+        for item in SplitDim(dim=0, keepdim=False)(t):
+            if isinstance(item, MetaTensor):
+                self.assertEqual(item.spatial_ndim, 2)
+
+    def test_lazy_apply_pending_2d(self):
+        """apply_pending uses spatial_ndim for 2D data instead of hardcoded 3."""
+        from monai.transforms.lazy.functional import apply_pending
+        from monai.utils.enums import LazyAttr
+
+        t = MetaTensor(torch.randn(1, 10, 10), affine=torch.eye(3))
+        self.assertEqual(t.spatial_ndim, 2)
+        # Push a pending 2D affine operation
+        pending_op = {
+            LazyAttr.AFFINE: torch.eye(3, dtype=torch.float64),
+            LazyAttr.SHAPE: (10, 10),
+            LazyAttr.INTERP_MODE: "bilinear",
+            LazyAttr.PADDING_MODE: "zeros",
+        }
+        t.push_pending_operation(pending_op)
+        result, applied = apply_pending(t, overrides={"mode": "bilinear"})
+        self.assertIsInstance(result, MetaTensor)
+        self.assertEqual(len(applied), 1)
+
+    def test_batch_slice_clamps_spatial_ndim(self):
+        t = MetaTensor(torch.randn(10, 6, 5, 7), affine=torch.eye(4))
+        t.is_batch = True
+        t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1)
+        self.assertEqual(t.spatial_ndim, 3)
+        sliced = t[0]
+        self.assertEqual(sliced.shape, (6, 5, 7))
+        self.assertEqual(sliced.spatial_ndim, 2)
+        self.assertEqual(get_spatial_ndim(sliced), 2)
+
+    def test_label_to_contour_batch_slice_2d(self):
+        t = MetaTensor(torch.randint(0, 2, (10, 6, 5, 7)).float(), affine=torch.eye(4))
+        t.is_batch = True
+        t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1)
+        sliced = t[0]
+        out = LabelToContour()(sliced)
+        self.assertEqual(out.shape, sliced.shape)
+
+    def test_rand_zoom_batch_slice_2d(self):
+        t = MetaTensor(torch.randn(10, 1, 64, 64), affine=torch.eye(4))
+        t.is_batch = True
+        t.meta["affine"] = torch.eye(4)[None].repeat(10, 1, 1)
+        sliced = t[0]
+        zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=1.2)
+        zoom.set_random_state(seed=0)
+        zoom.randomize(sliced)
+        self.assertEqual(len(zoom._zoom), 2)
+        out = zoom(sliced)
+        self.assertEqual(out.ndim, sliced.ndim)
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_einops_rearrange_then_resize(self):
+        """Reproduce the exact #6397 bug: einops.rearrange -> Resize."""
+        from einops import rearrange
+
+        x = MetaTensor(torch.randn(1, 1, 64, 64, 3))
+        x.is_batch = True
+        x.meta["affine"] = torch.eye(4)[None]
+        x_ = rearrange(x, "b c h w d -> (b c) h w d")
+        self.assertIsInstance(x_, MetaTensor)
+        self.assertEqual(x_.spatial_ndim, 3)
+        out = Resize(spatial_size=(32, 32, 3), mode="trilinear", align_corners=True)(x_)
+        self.assertEqual(out.shape[-3:], (32, 32, 3))
+
+    def test_affine_inverse_2d_metatensor(self):
+        """Affine.inverse on 2D data: 4x4 affine with spatial_ndim=2."""
+        img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4))
+        self.assertEqual(img.spatial_ndim, 2)
+        xform = Affine(rotate_params=(np.pi / 6,), padding_mode="zeros", image_only=True)
+        result = xform(img)
+        inv = xform.inverse(result)
+        self.assertEqual(inv.shape, img.shape)
+        self.assertEqual(len(inv.applied_operations), 0)
+
+    def test_rotate_inverse_2d_metatensor(self):
+        """Rotate.inverse on 2D data: 4x4 affine with spatial_ndim=2."""
+        img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4))
+        self.assertEqual(img.spatial_ndim, 2)
+        xform = Rotate(angle=(np.pi / 4,), padding_mode="zeros")
+        result = xform(img)
+        inv = xform.inverse(result)
+        self.assertEqual(inv.shape, img.shape)
+        self.assertEqual(len(inv.applied_operations), 0)
+
+    def test_rand_affine_inverse_2d_metatensor(self):
+        """RandAffine.inverse on 2D data: 4x4 affine with spatial_ndim=2."""
+        img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4))
+        self.assertEqual(img.spatial_ndim, 2)
+        xform = RandAffine(prob=1.0, rotate_range=(np.pi / 6,), padding_mode="zeros")
+        xform.set_random_state(seed=42)
+        result = xform(img)
+        inv = xform.inverse(result)
+        self.assertEqual(inv.shape, img.shape)
+        self.assertEqual(len(inv.applied_operations), 0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/transforms/test_squeezedim.py b/tests/transforms/test_squeezedim.py
index 5fd333d821..8e838629f4 100644
--- a/tests/transforms/test_squeezedim.py
+++ b/tests/transforms/test_squeezedim.py
@@ -38,6 +38,7 @@ def test_shape(self, input_param, test_data, expected_shape):
 
         self.assertTupleEqual(result.shape, expected_shape)
         if "dim" in input_param and input_param["dim"] == 2 and isinstance(result,
MetaTensor): assert_allclose(result.affine.shape, [3, 3]) + self.assertEqual(result.spatial_ndim, result.affine.shape[-1] - 1) @parameterized.expand(TESTS_FAIL) def test_invalid_inputs(self, exception, input_param, test_data): diff --git a/tests/transforms/utility/test_splitdim.py b/tests/transforms/utility/test_splitdim.py index 31d9983a2b..b15afb24f2 100644 --- a/tests/transforms/utility/test_splitdim.py +++ b/tests/transforms/utility/test_splitdim.py @@ -16,6 +16,7 @@ import numpy as np from parameterized import parameterized +from monai.data import MetaTensor from monai.transforms.utility.array import SplitDim from tests.test_utils import TEST_NDARRAYS @@ -47,6 +48,39 @@ def test_singleton(self): out = SplitDim(dim=1)(arr) self.assertEqual(out[0].shape, shape) + def test_spatial_ndim_decremented(self): + """spatial_ndim decremented for keepdim=False on spatial dim.""" + import torch + + arr = MetaTensor(torch.randn(2, 3, 8, 7)) + self.assertEqual(arr.spatial_ndim, 3) + out = SplitDim(dim=1, keepdim=False)(arr) + for item in out: + self.assertIsInstance(item, MetaTensor) + self.assertEqual(item.spatial_ndim, 2) + + def test_spatial_ndim_negative_dim(self): + """spatial_ndim decremented for keepdim=False with negative dim.""" + import torch + + arr = MetaTensor(torch.randn(2, 3, 8, 7)) + self.assertEqual(arr.spatial_ndim, 3) + out = SplitDim(dim=-1, keepdim=False)(arr) + for item in out: + self.assertIsInstance(item, MetaTensor) + self.assertEqual(item.spatial_ndim, 2) + + def test_spatial_ndim_channel_dim_no_decrement(self): + """spatial_ndim not decremented for keepdim=False on channel dim (dim=0).""" + import torch + + arr = MetaTensor(torch.randn(3, 8, 7)) + self.assertEqual(arr.spatial_ndim, 2) + out = SplitDim(dim=0, keepdim=False)(arr) + for item in out: + self.assertIsInstance(item, MetaTensor) + self.assertEqual(item.spatial_ndim, 2) + if __name__ == "__main__": unittest.main()