diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py index 6e08a88e59..fa714daad1 100644 --- a/monai/apps/detection/transforms/box_ops.py +++ b/monai/apps/detection/transforms/box_ops.py @@ -267,7 +267,7 @@ def convert_box_to_mask( boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b]) # apply to global mask slicing = [b] - slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore + slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type: ignore boxes_mask_np[tuple(slicing)] = boxes_only_mask return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0] diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 8d662df83d..4408d602bd 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None: raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2") root: str child_key: str - (root, _, child_key) = keys + root, _, child_key = keys if root not in self.ops: self.ops[root] = [{}] self.ops[root][0].update({child_key: None}) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 9fdee6acd0..fa9ba27096 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1948,7 +1948,7 @@ def create_workflow( """ _args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs) - (workflow_name, config_file) = _pop_args( + workflow_name, config_file = _pop_args( _args, workflow_name=ConfigWorkflow, config_file=None ) # the default workflow name is "ConfigWorkflow" if isinstance(workflow_name, str): diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 066cec41b7..21b24840b5 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -139,7 +139,7 @@ class DatasetFunc(Dataset): """ def __init__(self, data: Any, func: Callable, **kwargs) -> None: - super().__init__(data=None, transform=None) # type:ignore + super().__init__(data=None, transform=None) # type: ignore self.src = data self.func = func self.kwargs = kwargs @@ -1635,7 +1635,7 @@ def _cachecheck(self, item_transformed): return (_data, _meta) return _data else: - item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore + item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type: ignore for i, _item in enumerate(item_transformed): for k in _item: meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}") diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index b6771f2dcc..02975039b3 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -48,7 +48,7 @@ def stopping_fn_from_loss() -> Callable[[Engine], Any]: """ def stopping_fn(engine: Engine) -> Any: - return -engine.state.output # type:ignore + return -engine.state.output # type: ignore return stopping_fn diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index a451b1a770..4a60e438cf 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -320,7 +320,7 @@ def get_edge_surface_distance( edges_spacing = None if use_subvoxels: edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape)) - (edges_pred, edges_gt, *areas) = get_mask_edges( + edges_pred, edges_gt, *areas = get_mask_edges( y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False ) if not edges_gt.any(): diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index ecf918f47a..c254693e2c 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -119,6 +119,8 @@ def get_transform_info(self) -> dict: """ Return a dictionary with the relevant information pertaining to an applied transform. """ + self._init_trace_threadlocal() + vals = ( self.__class__.__name__, id(self), @@ -300,25 +302,45 @@ def track_transform_meta( return out_obj def check_transforms_match(self, transform: Mapping) -> None: - """Check transforms are of same instance.""" - xform_id = transform.get(TraceKeys.ID, "") - if xform_id == id(self): - return - # TraceKeys.NONE to skip the id check - if xform_id == TraceKeys.NONE: + """Check whether a traced transform entry matches this transform. + + When multiprocessing uses ``spawn``, transform instances are recreated, + so matching can fall back to the transform class name instead of the + original instance ID. + """ + if self._transforms_match(transform): return + + xform_id = transform.get(TraceKeys.ID, "") xform_name = transform.get(TraceKeys.CLASS_NAME, "") warning_msg = transform.get(TraceKeys.EXTRA_INFO, {}).get("warn") if warning_msg: warnings.warn(warning_msg) - # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) - if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: - return raise RuntimeError( f"Error {self.__class__.__name__} getting the most recently " f"applied invertible transform {xform_name} {xform_id} != {id(self)}." ) + def _transforms_match(self, transform: Mapping) -> bool: + """Return whether a traced transform entry matches this transform. + + Matching succeeds when the traced ID matches this instance, when the ID + check is explicitly disabled with ``TraceKeys.NONE``, or when + multiprocessing uses ``spawn`` and the traced class name matches this + transform class. + """ + xform_id = transform.get(TraceKeys.ID, "") + if xform_id == id(self): + return True + # TraceKeys.NONE to skip the id check + if xform_id == TraceKeys.NONE: + return True + xform_name = transform.get(TraceKeys.CLASS_NAME, "") + # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) + if torch.multiprocessing.get_start_method(allow_none=True) == "spawn" and xform_name == self.__class__.__name__: + return True + return False + def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False): """ Get most recent matching transform for the current class from the sequence of applied operations. @@ -350,10 +372,16 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr if not all_transforms: raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'") + match_idx = len(all_transforms) - 1 if check: - self.check_transforms_match(all_transforms[-1]) + for idx in range(len(all_transforms) - 1, -1, -1): + if self._transforms_match(all_transforms[idx]): + match_idx = idx + break + else: + self.check_transforms_match(all_transforms[-1]) - return all_transforms.pop(-1) if pop else all_transforms[-1] + return all_transforms.pop(match_idx) if pop else all_transforms[match_idx] def pop_transform(self, data, key: Hashable = None, check: bool = True): """ diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 0628a7fbc4..f0c1d1949d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -11,6 +11,7 @@ """ A collection of "vanilla" transforms for IO functions. """ + from __future__ import annotations import inspect diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 3dc7897feb..7df6e2c5ef 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -702,7 +702,7 @@ def __init__( # if the root log level is higher than INFO, set a separate stream handler to record console = logging.StreamHandler(sys.stdout) console.setLevel(logging.INFO) - console.is_data_stats_handler = True # type:ignore[attr-defined] + console.is_data_stats_handler = True # type: ignore[attr-defined] _logger.addHandler(console) def __call__( diff --git a/tests/integration/test_loader_semaphore.py b/tests/integration/test_loader_semaphore.py index 78baedc264..c32bcb0b8b 100644 --- a/tests/integration/test_loader_semaphore.py +++ b/tests/integration/test_loader_semaphore.py @@ -10,6 +10,7 @@ # limitations under the License. """this test should not generate errors or UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores""" + from __future__ import annotations import multiprocessing as mp diff --git a/tests/profile_subclass/profiling.py b/tests/profile_subclass/profiling.py index 18aecea2fb..6106259526 100644 --- a/tests/profile_subclass/profiling.py +++ b/tests/profile_subclass/profiling.py @@ -12,6 +12,7 @@ Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark """ + from __future__ import annotations import argparse diff --git a/tests/profile_subclass/pyspy_profiling.py b/tests/profile_subclass/pyspy_profiling.py index fac425f577..671dc74c01 100644 --- a/tests/profile_subclass/pyspy_profiling.py +++ b/tests/profile_subclass/pyspy_profiling.py @@ -12,6 +12,7 @@ To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark """ + from __future__ import annotations import argparse diff --git a/tests/transforms/croppad/test_pad_nd_dtypes.py b/tests/transforms/croppad/test_pad_nd_dtypes.py index 7fa633b8aa..a3f5f93a2d 100644 --- a/tests/transforms/croppad/test_pad_nd_dtypes.py +++ b/tests/transforms/croppad/test_pad_nd_dtypes.py @@ -12,6 +12,7 @@ Tests for pad_nd dtype support and backend selection. Validates PyTorch padding preference and NumPy fallback behavior. """ + from __future__ import annotations import unittest diff --git a/tests/transforms/inverse/test_invertd.py b/tests/transforms/inverse/test_invertd.py index 2b5e9da85d..eaed24e15d 100644 --- a/tests/transforms/inverse/test_invertd.py +++ b/tests/transforms/inverse/test_invertd.py @@ -13,11 +13,12 @@ import sys import unittest +from unittest.mock import patch import numpy as np import torch -from monai.data import DataLoader, Dataset, create_test_image_3d, decollate_batch +from monai.data import DataLoader, Dataset, MetaTensor, create_test_image_2d, create_test_image_3d, decollate_batch from monai.transforms import ( CastToTyped, Compose, @@ -36,7 +37,10 @@ ScaleIntensityd, Spacingd, ) -from monai.utils import set_determinism +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.transforms.utility.dictionary import Lambdad +from monai.utils import TraceKeys, set_determinism from tests.test_utils import assert_allclose, make_nifti_image KEYS = ["image", "label"] @@ -137,6 +141,236 @@ def test_invert(self): set_determinism(seed=None) + def test_invertd_with_postprocessing_transforms(self): + """Test that Invertd ignores unrelated trailing transforms while inverting.""" + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # Preprocessing pipeline + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) + + # Postprocessing with Lambdad before Invertd + # Previously this would raise RuntimeError about transform ID mismatch + postprocessing = Compose( + [ + Lambdad(key, func=lambda x: x), # Should be ignored during inversion + Invertd(key, transform=preprocessing, orig_keys=key), + ] + ) + + # Apply transforms + item = {key: img} + pre = preprocessing(item) + + # This should NOT raise an error (was failing before the fix). + # Any exception here means the bug is not fixed. + post = postprocessing(pre) + self.assertIsNotNone(post) + self.assertIn(key, post) + self.assertTupleEqual(tuple(post[key].shape), (1, 60, 60)) + self.assertEqual(len(post[key].applied_operations), 1) + self.assertEqual(post[key].applied_operations[0][TraceKeys.CLASS_NAME], "Lambda") + + def test_invertd_multiple_pipelines(self): + """Test that Invertd correctly handles multiple independent preprocessing pipelines.""" + img1, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img1 = MetaTensor(img1, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + img2, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img2 = MetaTensor(img2, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + + # Two different preprocessing pipelines + preprocessing1 = Compose([EnsureChannelFirstd("image1"), Spacingd("image1", pixdim=[2.0, 2.0])]) + + preprocessing2 = Compose([EnsureChannelFirstd("image2"), Spacingd("image2", pixdim=[1.5, 1.5])]) + + # Postprocessing that inverts both + postprocessing = Compose( + [ + Lambdad(["image1", "image2"], func=lambda x: x), + Invertd("image1", transform=preprocessing1, orig_keys="image1"), + Invertd("image2", transform=preprocessing2, orig_keys="image2"), + ] + ) + + # Apply transforms + item = {"image1": img1, "image2": img2} + pre1 = preprocessing1(item) + pre2 = preprocessing2(pre1) + + # Should not raise error - each Invertd should only invert its own pipeline + post = postprocessing(pre2) + self.assertIn("image1", post) + self.assertIn("image2", post) + + def test_invertd_multiple_postprocessing_transforms(self): + """Test Invertd with multiple invertible transforms in postprocessing before Invertd.""" + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) + + # Multiple transforms in postprocessing before Invertd + postprocessing = Compose( + [ + Lambdad(key, func=lambda x: x * 2), + Lambdad(key, func=lambda x: x + 1), + Lambdad(key, func=lambda x: x - 1), + Invertd(key, transform=preprocessing, orig_keys=key), + ] + ) + + item = {key: img} + pre = preprocessing(item) + post = postprocessing(pre) + + self.assertIsNotNone(post) + self.assertIn(key, post) + + def test_invertd_preserves_unrelated_postprocessing_history(self): + """Test that Invertd only removes the transforms it actually inverts.""" + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) + postprocessing = Compose([Lambdad(key, func=lambda x: x), Lambdad(key, func=lambda x: x)]) + + item = {key: img} + pre = preprocessing(item) + post = postprocessing(pre) + + with patch("torch.multiprocessing.get_start_method", return_value=None): + inverter = Invertd(key, transform=preprocessing, orig_keys=key) + inverted = inverter(post) + + self.assertTupleEqual(tuple(inverted[key].shape), (1, 60, 60)) + self.assertEqual([op[TraceKeys.CLASS_NAME] for op in inverted[key].applied_operations], ["Lambda", "Lambda"]) + + def test_invertd_preserves_same_class_postprocessing_history(self): + """Test MetaTensor inversion when trailing history contains the same transform class.""" + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) + postprocessing = Compose([Spacingd(key, pixdim=[1.5, 1.5])]) + + item = {key: img} + pre = preprocessing(item) + post = postprocessing(pre) + + with patch("torch.multiprocessing.get_start_method", return_value=None): + inverter = Invertd(key, transform=preprocessing, orig_keys=key) + inverted = inverter(post) + + self.assertTupleEqual(tuple(inverted[key].shape), (1, 60, 60)) + self.assertEqual(len(inverted[key].applied_operations), 1) + self.assertEqual(inverted[key].applied_operations[0][TraceKeys.CLASS_NAME], "SpatialResample") + + def test_invertd_ignores_unrelated_trace_key_history(self): + """Test trace-key inversion when unrelated invertible transforms trail the target history.""" + + class _IdentityMapInvertible(MapTransform, InvertibleTransform): + def __init__(self, keys): + super().__init__(keys) + + def __call__(self, data): + d = dict(data) + self.push_transform(d, key=self.keys[0]) + return d + + def inverse(self, data): + d = dict(data) + self.pop_transform(d, key=self.keys[0]) + return d + + key = "image" + target_transform = _IdentityMapInvertible(key) + other_transform = _IdentityMapInvertible(key) + item = {key: torch.zeros((1, 8, 8), dtype=torch.float32)} + item = target_transform(item) + item = other_transform(item) + + with patch("torch.multiprocessing.get_start_method", return_value=None): + inverter = Invertd(key, transform=target_transform, orig_keys=key, nearest_interp=False) + inverted = inverter(item) + + trace_key = InvertibleTransform.trace_key(key) + self.assertEqual(len(inverted[trace_key]), 1) + self.assertEqual(inverted[trace_key][0][TraceKeys.ID], id(other_transform)) + + def test_compose_inverse(self): + """Test that Compose.inverse() works correctly on its own transform history.""" + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # Create a preprocessing pipeline + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) + + # Apply preprocessing + item = {key: img} + pre = preprocessing(item) + + # Call inverse() directly on the Compose object + inverted = preprocessing.inverse(pre) + + # Should successfully invert + self.assertIsNotNone(inverted) + self.assertIn(key, inverted) + # Shape should be restored after inversion + self.assertEqual(inverted[key].shape[1:], img.shape) + + def test_compose_inverse_with_postprocessing_transforms(self): + """Test Compose.inverse() when unrelated postprocessing transforms trail the target history.""" + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # Preprocessing pipeline + preprocessing = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) + + # Postprocessing pipeline whose transforms should remain after the preprocessing inverse + postprocessing = Compose([Lambdad(key, func=lambda x: x)]) + + # Apply both pipelines + item = {key: img} + pre = preprocessing(item) + post = postprocessing(pre) + + # Calling inverse() directly should restore the preprocessing changes without consuming the + # unrelated postprocessing transform entry. + inverted = preprocessing.inverse(post) + self.assertTupleEqual(tuple(inverted[key].shape), (1, 60, 60)) + self.assertEqual(len(inverted[key].applied_operations), 1) + self.assertEqual(inverted[key].applied_operations[0][TraceKeys.CLASS_NAME], "Lambda") + + def test_mixed_invertd_and_compose_inverse(self): + """Test using Invertd and Compose.inverse() on the same pipeline history.""" + img, _ = create_test_image_2d(60, 60, 2, 10, num_seg_classes=2) + img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0, 1.0]}) + key = "image" + + # First pipeline + pipeline1 = Compose([EnsureChannelFirstd(key), Spacingd(key, pixdim=[2.0, 2.0])]) + + # Apply first pipeline + item = {key: img} + result1 = pipeline1(item) + + # Use Compose.inverse() directly - should work fine + inverted1 = pipeline1.inverse(result1) + self.assertIsNotNone(inverted1) + self.assertEqual(inverted1[key].shape[1:], img.shape) + + # Now apply pipeline again and use Invertd + result2 = pipeline1(item) + inverter = Invertd(key, transform=pipeline1, orig_keys=key) + inverted2 = inverter(result2) + self.assertIsNotNone(inverted2) + if __name__ == "__main__": unittest.main() diff --git a/versioneer.py b/versioneer.py index a06587fc3f..5d0a606c91 100644 --- a/versioneer.py +++ b/versioneer.py @@ -273,6 +273,7 @@ [travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer """ + # pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring # pylint:disable=missing-class-docstring,too-many-branches,too-many-statements # pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error