Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
Expand Down Expand Up @@ -3565,33 +3565,35 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:

if self._do_transform:
input_shape = img.shape[1:]
target_shape = tuple(np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_).tolist())

resize_tfm_downsample = Resize(
spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False
)

resize_tfm_upsample = Resize(
spatial_size=input_shape,
size_mode="all",
mode=self.upsample_mode,
anti_aliasing=False,
align_corners=self.align_corners,
# Clamp each axis to at least 1 so F.interpolate never sees a zero-sized dimension.
target_shape = tuple(max(1, int(np.round(s * self.zoom_factor))) for s in input_shape)

# Use F.interpolate directly on a plain tensor to avoid mutating the global
# set_track_meta flag, which is not thread-safe (see GitHub issue #8409).
img_t = convert_to_tensor(img, track_meta=False)
# F.interpolate requires float input and a batch dimension; cast matches
# the default dtype=float32 that Resize uses internally.
img_float = img_t.unsqueeze(0).to(dtype=torch.float32)

downsample_mode = str(self.downsample_mode)
upsample_mode = str(self.upsample_mode)
# align_corners is only valid for linear/bilinear/bicubic/trilinear modes
_align_corners_modes = {"linear", "bilinear", "bicubic", "trilinear"}
downsample_align_corners = self.align_corners if downsample_mode in _align_corners_modes else None
upsample_align_corners = self.align_corners if upsample_mode in _align_corners_modes else None

img_downsampled = torch.nn.functional.interpolate(
img_float, size=target_shape, mode=downsample_mode, align_corners=downsample_align_corners
)
Comment thread
ericspod marked this conversation as resolved.
# temporarily disable metadata tracking, since we do not want to invert the two Resize functions during
# post-processing
original_tack_meta_value = get_track_meta()
set_track_meta(False)

img_downsampled = resize_tfm_downsample(img)
img_upsampled = resize_tfm_upsample(img_downsampled)

# reset metadata tracking to original value
set_track_meta(original_tack_meta_value)

# copy metadata from original image to down-and-upsampled image
img_upsampled = MetaTensor(img_upsampled)
img_upsampled.copy_meta_from(img)
img_upsampled_t = torch.nn.functional.interpolate(
img_downsampled, size=input_shape, mode=upsample_mode, align_corners=upsample_align_corners
).squeeze(0)
Comment thread
chhayankjain marked this conversation as resolved.

# copy metadata from original image to down-and-upsampled image,
# respecting the caller's get_track_meta() setting.
img_upsampled = cast(torch.Tensor, convert_to_tensor(img_upsampled_t, track_meta=get_track_meta()))
if isinstance(img_upsampled, MetaTensor):
img_upsampled.copy_meta_from(img)

return img_upsampled

Expand Down
45 changes: 45 additions & 0 deletions tests/transforms/test_rand_simulate_low_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@

from __future__ import annotations

import threading
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.data.meta_obj import get_track_meta
from monai.transforms import RandSimulateLowResolution
from tests.test_utils import TEST_NDARRAYS, assert_allclose

Expand Down Expand Up @@ -78,6 +81,48 @@ def test_value(self, arguments, image, expected_data):
result = randsimlowres(image)
assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor")

def test_track_meta_global_state_unchanged(self):
    """Regression test for GitHub issue #8409.

    Applying the transform must leave the process-wide ``set_track_meta``
    flag exactly as the caller configured it.
    """
    image = torch.ones(1, 4, 4, 4)
    transform = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6))
    transform.set_random_state(seed=0)

    # Snapshot the flag, run the transform once, then confirm nothing moved.
    flag_before = get_track_meta()
    transform(image)
    self.assertEqual(get_track_meta(), flag_before, "set_track_meta global state was unexpectedly modified")

def test_thread_safety(self):
    """Regression test for GitHub issue #8409.

    Run the transform concurrently from several threads and verify that no
    worker ever observes a mutated global track_meta flag. The baseline flag
    is read once, before any thread starts, so every worker compares against
    the same value rather than a snapshot a sibling may already have
    corrupted.
    """
    errors = []
    baseline_track_meta = get_track_meta()
    # All workers block on the barrier and then call the transform at once,
    # maximising the chance of exposing a shared-state race.
    barrier = threading.Barrier(8)

    def worker():
        image = torch.ones(1, 4, 4, 4)
        transform = RandSimulateLowResolution(prob=1.0, zoom_range=(0.5, 0.6))
        barrier.wait()
        try:
            for _ in range(50):
                transform(image)
                if get_track_meta() != baseline_track_meta:
                    errors.append(RuntimeError("track_meta state changed in thread"))
                    break
        except Exception as exc:  # collect any worker failure for the main thread
            errors.append(exc)

    workers = [threading.Thread(target=worker) for _ in range(8)]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()

    self.assertEqual(errors, [], f"Thread safety errors: {errors}")
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
chhayankjain marked this conversation as resolved.


if __name__ == "__main__":
unittest.main()
Loading