Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
monai_to_itk_ddf,
)
from .meta_obj import MetaObj, get_track_meta, set_track_meta
from .meta_tensor import MetaTensor
from .meta_tensor import MetaTensor, get_spatial_ndim
from .samplers import DistributedSampler, DistributedWeightedRandomSampler
from .synthetic import create_test_image_2d, create_test_image_3d
from .test_time_augmentation import TestTimeAugmentation
Expand Down
4 changes: 4 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

_TRACK_META = True

# Default number of spatial dimensions for medical imaging (3D volumetric data)
_DEFAULT_SPATIAL_NDIM = 3

__all__ = ["get_track_meta", "set_track_meta", "MetaObj"]


Expand Down Expand Up @@ -84,6 +87,7 @@ def __init__(self) -> None:
self._applied_operations: list = MetaObj.get_default_applied_operations()
self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops
self._is_batch: bool = False
self._spatial_ndim: int = 3 # default: 3 spatial dimensions

@staticmethod
def flatten_meta_objs(*args: Iterable):
Expand Down
103 changes: 91 additions & 12 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,59 @@

import functools
import warnings
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any

import numpy as np
import torch

import monai
from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor

__all__ = ["MetaTensor"]
__all__ = ["MetaTensor", "get_spatial_ndim"]


def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int:
"""Clamp spatial dims to a valid range for the current tensor shape."""
limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1)
return max(1, min(int(spatial_ndim), limit))


def _has_explicit_no_channel(meta: Mapping | None) -> bool:
    """Return True when ``meta`` explicitly records that the data carries no channel dim.

    Checks that ``meta`` is a mapping containing ``MetaKeys.ORIGINAL_CHANNEL_DIM`` and
    that its value is the "no channel" marker (as judged by ``is_no_channel``).
    """
    if not isinstance(meta, Mapping):
        return False
    if MetaKeys.ORIGINAL_CHANNEL_DIM not in meta:
        return False
    return is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM])


def get_spatial_ndim(img: NdarrayOrTensor) -> int:
    """Return the number of spatial dimensions assuming channel-first layout.

    Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
    ``img.ndim - 1``. Always assumes channel-first (``no_channel=False``)
    because callers run after ``EnsureChannelFirst`` has already added one.
    """
    if not isinstance(img, MetaTensor):
        # plain array/tensor: one leading channel axis, the rest are spatial
        return img.ndim - 1
    # clamp the cached value so stale metadata cannot exceed the actual rank
    return _normalize_spatial_ndim(img.spatial_ndim, img.ndim)


def _is_batch_only_index(index: Any) -> bool:
"""True when indexing pattern selects only the batch axis (e.g., ``x[0]`` or ``x[0, ...]``)."""
if isinstance(index, (int, np.integer)):
return True
if not isinstance(index, Sequence) or not index:
return False
if not isinstance(index[0], (int, np.integer)):
return False
return all(i in (slice(None, None, None), Ellipsis, None) for i in index[1:])


@functools.lru_cache(None)
Expand Down Expand Up @@ -111,6 +148,7 @@ def __new__(
meta: dict | None = None,
applied_operations: list | None = None,
*args,
spatial_ndim: int | None = None,
**kwargs,
) -> MetaTensor:
_kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
Expand All @@ -123,6 +161,7 @@ def __init__(
meta: dict | None = None,
applied_operations: list | None = None,
*_args,
spatial_ndim: int | None = None,
**_kwargs,
) -> None:
"""
Expand All @@ -134,6 +173,8 @@ def __init__(
the list is typically maintained by `monai.transforms.TraceableTransform`.
See also: :py:class:`monai.transforms.TraceableTransform`
_args: additional args (currently not in use in this constructor).
spatial_ndim: optional number of spatial dimensions. If ``None``, derived
from the affine matrix clamped by the tensor shape.
_kwargs: additional kwargs (currently not in use in this constructor).

Note:
Expand All @@ -158,6 +199,14 @@ def __init__(
self.affine = self.meta[MetaKeys.AFFINE]
else:
self.affine = self.get_default_affine()
# Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape.
# This cached value is kept in sync via the affine setter for hot-path performance.
no_channel = _has_explicit_no_channel(self.meta)
if spatial_ndim is not None:
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel)
elif self.affine.ndim == 2:
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel)

# applied_operations
if applied_operations is not None:
self.applied_operations = applied_operations
Expand Down Expand Up @@ -237,6 +286,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
if func == torch.Tensor.__getitem__:
if idx > 0 or len(args) < 2 or len(args[0]) < 1:
return ret
full_idx = args[1]
batch_idx = args[1][0] if isinstance(args[1], Sequence) else args[1]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
Expand All @@ -258,6 +308,8 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
ret_meta.is_batch = False
if hasattr(ret_meta, "__dict__"):
ret.__dict__ = ret_meta.__dict__.copy()
if _is_batch_only_index(full_idx):
ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim)
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
# But we only want to split the batch if the `unbind` is along the 0th dimension.
elif func == torch.Tensor.unbind:
Expand Down Expand Up @@ -467,15 +519,40 @@ def affine(self) -> torch.Tensor:

@affine.setter
def affine(self, d: NdarrayTensor) -> None:
"""Set the affine."""
self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
"""Set the affine.

When setting a non-batched affine matrix, automatically synchronizes the cached
spatial_ndim attribute to maintain consistency between the affine matrix (source of truth)
and the cached spatial dimension count.
"""
a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
self.meta[MetaKeys.AFFINE] = a
if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth)
no_channel = _has_explicit_no_channel(self.meta)
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel)

@property
def spatial_ndim(self) -> int:
    """Get the number of spatial dimensions.

    This value is cached for hot-path performance and is kept in sync with the affine matrix
    via the affine setter. The affine matrix is the source of truth for spatial dimensions.
    Falls back to the package-wide default when no value has been set yet.
    """
    return getattr(self, "_spatial_ndim", _DEFAULT_SPATIAL_NDIM)

@spatial_ndim.setter
def spatial_ndim(self, val: int) -> None:
    """Set the number of spatial dimensions.

    Raises:
        TypeError: if ``val`` is not an integer (``bool`` is rejected explicitly,
            since ``isinstance(True, int)`` is True but a bool is not a valid rank).
        ValueError: if ``val`` is less than 1.
    """
    # enforce integer rank semantics: floats/bools must not slip into _spatial_ndim,
    # where they would destabilize rank-dependent operations downstream
    if isinstance(val, bool) or not isinstance(val, (int, np.integer)):
        raise TypeError(f"spatial_ndim must be an integer, got {type(val).__name__}")
    if val < 1:
        raise ValueError(f"spatial_ndim must be >= 1, got {val}")
    self._spatial_ndim = int(val)
Comment on lines +544 to +548
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Harden spatial_ndim setter to enforce integer rank semantics.

Line 518 only checks < 1; non-integer values can slip through and destabilize rank-dependent operations.

Proposed fix
+from numbers import Integral
...
     `@spatial_ndim.setter`
     def spatial_ndim(self, val: int) -> None:
         """Set the number of spatial dimensions."""
+        if not isinstance(val, Integral):
+            raise TypeError(f"spatial_ndim must be an integer, got {type(val).__name__}")
         if val < 1:
             raise ValueError(f"spatial_ndim must be >= 1, got {val}")
-        self._spatial_ndim = val
+        self._spatial_ndim = _normalize_spatial_ndim(int(val), self.ndim)
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 519-519: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/data/meta_tensor.py` around lines 516 - 520, The spatial_ndim setter
currently only checks val < 1 which allows non-integer values; update the setter
for spatial_ndim to enforce integer rank semantics by validating the type and
range: ensure the incoming val is an integer (and not a bool) and then check val
>= 1, otherwise raise a TypeError for non-integers or a ValueError for values <
1; modify the spatial_ndim setter method to perform these checks before
assigning to self._spatial_ndim.


@property
def pixdim(self):
"""Get the spacing"""
if self.is_batch:
return [affine_to_spacing(a) for a in self.affine]
return affine_to_spacing(self.affine)
return [affine_to_spacing(a, r=self.spatial_ndim) for a in self.affine]
return affine_to_spacing(self.affine, r=self.spatial_ndim)

def peek_pending_shape(self):
"""
Expand All @@ -490,7 +567,7 @@ def peek_pending_shape(self):

def peek_pending_affine(self):
res = self.affine
r = len(res) - 1
r = res.shape[-1] - 1 if res.ndim >= 2 else self.spatial_ndim
if r not in (2, 3):
warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.")
for p in self.pending_operations:
Expand All @@ -503,8 +580,10 @@ def peek_pending_affine(self):
return res

def peek_pending_rank(self):
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
return 1 if a is None else int(max(1, len(a) - 1))
if self.pending_operations:
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
return 1 if a is None else int(max(1, len(a) - 1))
return self.spatial_ndim

def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override]
"""
Expand Down
3 changes: 2 additions & 1 deletion monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torch.utils.data._utils.collate import default_collate

from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
from monai.data.meta_obj import MetaObj
from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj
from monai.utils import (
MAX_SEED,
BlendMode,
Expand Down Expand Up @@ -432,6 +432,7 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
collated.meta = default_collate(meta_dicts)
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
collated.is_batch = True
collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", _DEFAULT_SPATIAL_NDIM), max(collated.ndim - 1, 1))
return collated


Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/croppad/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
from monai.data.utils import to_affine_nd
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import convert_pad_mode, create_translate
Expand Down Expand Up @@ -132,7 +132,7 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
mode: the padding mode.
kwargs: other arguments for the `np.pad` or `torch.pad` function.
"""
ndim = len(img.shape) - 1
ndim = get_spatial_ndim(img)
matrix_np = np.round(to_affine_nd(ndim, convert_to_numpy(translation_mat, wrap_sequence=True).copy()))
matrix_np = to_affine_nd(len(spatial_size), matrix_np)
cc = np.asarray(np.meshgrid(*[[0.5, x - 0.5] for x in spatial_size], indexing="ij"))
Expand Down
14 changes: 8 additions & 6 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.config import DtypeLike
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import get_spatial_ndim
from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
Expand Down Expand Up @@ -1580,7 +1581,7 @@ def __init__(self, radius: Sequence[int] | int = 1) -> None:
def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
spatial_dims = img_t.ndim - 1
spatial_dims = get_spatial_ndim(img)
r = ensure_tuple_rep(self.radius, spatial_dims)
median_filter_instance = MedianFilter(r, spatial_dims=spatial_dims)
out_t: torch.Tensor = median_filter_instance(img_t)
Expand Down Expand Up @@ -1616,7 +1617,7 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma]
else:
sigma = torch.as_tensor(self.sigma, device=img_t.device)
gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx)
gaussian_filter = GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx)
out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0)
out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)

Expand Down Expand Up @@ -1673,7 +1674,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
if not self._do_transform:
return img

sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=img.ndim - 1)
sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=get_spatial_ndim(img))
return GaussianSmooth(sigma=sigma, approx=self.approx)(img)


Expand Down Expand Up @@ -1723,7 +1724,7 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32)

gf1, gf2 = (
GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device)
GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx).to(img_t.device)
for sigma in (self.sigma1, self.sigma2)
)
blurred_f = gf1(img_t.unsqueeze(0))
Expand Down Expand Up @@ -1811,8 +1812,9 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen

if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None:
raise RuntimeError("please call the `randomize()` function first.")
sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=img.ndim - 1)
sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=img.ndim - 1)
_sp = get_spatial_ndim(img)
sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=_sp)
sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=_sp)
return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img)


Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def track_transform_meta(
orig_affine = data_t.peek_pending_affine()
orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
try:
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
affine = orig_affine @ to_affine_nd(orig_affine.shape[-1] - 1, affine, dtype=torch.float64)
except RuntimeError as e:
if orig_affine.ndim > 2:
if data_t.is_batch:
Expand Down
10 changes: 6 additions & 4 deletions monai/transforms/lazy/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,11 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
if not pending:
return data, []

_rank = data.spatial_ndim if isinstance(data, MetaTensor) else 3

cumulative_xform = affine_from_pending(pending[0])
if cumulative_xform.shape[0] == 3:
cumulative_xform = to_affine_nd(3, cumulative_xform)
if cumulative_xform.shape[0] < _rank + 1:
cumulative_xform = to_affine_nd(_rank, cumulative_xform)

cur_kwargs = kwargs_from_pending(pending[0])
override_kwargs: dict[str, Any] = {}
Expand All @@ -283,8 +285,8 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
data = resample(data.to(device), cumulative_xform, _cur_kwargs)

next_matrix = affine_from_pending(p)
if next_matrix.shape[0] == 3:
next_matrix = to_affine_nd(3, next_matrix)
if next_matrix.shape[0] < _rank + 1:
next_matrix = to_affine_nd(_rank, next_matrix)

cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
cur_kwargs.update(new_kwargs)
Expand Down
10 changes: 7 additions & 3 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
from monai.networks import one_hot
from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
from monai.transforms.inverse import InvertibleTransform
Expand Down Expand Up @@ -624,7 +624,11 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
spatial_dims = len(img_.shape) - 1
spatial_dims = get_spatial_ndim(img)
# Validate actual tensor shape against tracked spatial_ndim
actual_spatial = img_.ndim - 1 # channel-first layout
if actual_spatial != spatial_dims:
spatial_dims = actual_spatial
img_ = img_.unsqueeze(0) # adds a batch dim
if spatial_dims == 2:
kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
Expand Down Expand Up @@ -1104,7 +1108,7 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
image_tensor = convert_to_tensor(image, track_meta=get_track_meta())

# Check/set spatial axes
n_spatial_dims = image_tensor.ndim - 1 # excluding the channel dimension
n_spatial_dims = get_spatial_ndim(image_tensor)
valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))

# Check gradient axes to be valid
Expand Down
Loading
Loading