diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index d6b5ed8641..e83750e414 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -26,7 +26,7 @@ from monai.config import USE_COMPILED, DtypeLike from monai.config.type_definitions import NdarrayOrTensor from monai.data.box_utils import BoxMode, StandardMode -from monai.data.meta_obj import get_track_meta, set_track_meta +from monai.data.meta_obj import get_track_meta from monai.data.meta_tensor import MetaTensor from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull @@ -3565,33 +3565,35 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: if self._do_transform: input_shape = img.shape[1:] - target_shape = tuple(np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_).tolist()) - - resize_tfm_downsample = Resize( - spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False - ) - - resize_tfm_upsample = Resize( - spatial_size=input_shape, - size_mode="all", - mode=self.upsample_mode, - anti_aliasing=False, - align_corners=self.align_corners, + # Clamp each axis to at least 1 so F.interpolate never sees a zero-sized dimension. + target_shape = tuple(max(1, int(np.round(s * self.zoom_factor))) for s in input_shape) + + # Use F.interpolate directly on a plain tensor to avoid mutating the global + # set_track_meta flag, which is not thread-safe (see GitHub issue #8409). + img_t = convert_to_tensor(img, track_meta=False) + # F.interpolate requires float input and a batch dimension; cast matches + # the default dtype=float32 that Resize uses internally. + img_float = img_t.unsqueeze(0).to(dtype=torch.float32) + + downsample_mode = str(self.downsample_mode) + upsample_mode = str(self.upsample_mode) + # align_corners is only valid for linear/bilinear/bicubic/trilinear modes + _align_corners_modes = {"linear", "bilinear", "bicubic", "trilinear"} + downsample_align_corners = self.align_corners if downsample_mode in _align_corners_modes else None + upsample_align_corners = self.align_corners if upsample_mode in _align_corners_modes else None + + img_downsampled = torch.nn.functional.interpolate( + img_float, size=target_shape, mode=downsample_mode, align_corners=downsample_align_corners ) - # temporarily disable metadata tracking, since we do not want to invert the two Resize functions during - # post-processing - original_tack_meta_value = get_track_meta() - set_track_meta(False) - - img_downsampled = resize_tfm_downsample(img) - img_upsampled = resize_tfm_upsample(img_downsampled) - - # reset metadata tracking to original value - set_track_meta(original_tack_meta_value) - - # copy metadata from original image to down-and-upsampled image - img_upsampled = MetaTensor(img_upsampled) - img_upsampled.copy_meta_from(img) + img_upsampled_t = torch.nn.functional.interpolate( + img_downsampled, size=input_shape, mode=upsample_mode, align_corners=upsample_align_corners + ).squeeze(0) + + # copy metadata from original image to down-and-upsampled image, + # respecting the caller's get_track_meta() setting. + img_upsampled = cast(torch.Tensor, convert_to_tensor(img_upsampled_t, track_meta=get_track_meta())) + if isinstance(img_upsampled, MetaTensor): + img_upsampled.copy_meta_from(img) return img_upsampled diff --git a/tests/transforms/test_rand_simulate_low_resolution.py b/tests/transforms/test_rand_simulate_low_resolution.py index 3a8032d152..46b3187fd8 100644 --- a/tests/transforms/test_rand_simulate_low_resolution.py +++ b/tests/transforms/test_rand_simulate_low_resolution.py @@ -11,11 +11,14 @@ from __future__ import annotations +import threading import unittest import numpy as np +import torch from parameterized import parameterized +from monai.data.meta_obj import get_track_meta from monai.transforms import RandSimulateLowResolution from tests.test_utils import TEST_NDARRAYS, assert_allclose @@ -78,6 +81,48 @@ def test_value(self, arguments, image, expected_data): result = randsimlowres(image) assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor") + def test_track_meta_global_state_unchanged(self): + # Verify that calling RandSimulateLowResolution does not modify the global + # set_track_meta flag (regression test for GitHub issue #8409). + img = torch.ones(1, 4, 4, 4) + tfm = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6)) + tfm.set_random_state(seed=0) + + original_track_meta = get_track_meta() + tfm(img) + self.assertEqual(get_track_meta(), original_track_meta, "set_track_meta global state was unexpectedly modified") + + def test_thread_safety(self): + # Verify that concurrent calls do not corrupt each other's track_meta state + # (regression test for GitHub issue #8409). + # expected_track_meta is captured before threads start so every worker + # checks against the same baseline rather than its own (possibly already + # corrupted) snapshot. + errors = [] + expected_track_meta = get_track_meta() + start_barrier = threading.Barrier(8) + + def run_transform(): + img = torch.ones(1, 4, 4, 4) + tfm = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6)) + start_barrier.wait() # synchronise so all threads hammer the transform at once + try: + for _ in range(50): + tfm(img) + if get_track_meta() != expected_track_meta: + errors.append(RuntimeError("track_meta state changed in thread")) + break + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=run_transform) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(errors, [], f"Thread safety errors: {errors}") + if __name__ == "__main__": unittest.main()