Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a607e70
Enables automatic transform group tracking for inversion
Dec 1, 2025
08d0076
autofix formatting
Dec 1, 2025
24bc645
fix errors
Dec 1, 2025
4a91861
Add Google-style docstrings for transform group helpers
eclipse0922 Feb 24, 2026
f1f92e0
Assign group IDs to wrapped transforms in Compose
eclipse0922 Feb 24, 2026
d110fdc
DCO Remediation Commit for sewon jeon <irocks0922@gmail.com>
eclipse0922 Feb 24, 2026
6d8b0c5
address the comments
eclipse0922 Feb 24, 2026
399fb2b
DCO Remediation Commit for sewon.jeon <irocks0922@gmail.com>
eclipse0922 Feb 24, 2026
95222f3
Address CodeRabbit follow-up review comments
eclipse0922 Feb 24, 2026
5abaa08
DCO Remediation Commit for sewon.jeon <irocks0922@gmail.com>
eclipse0922 Feb 24, 2026
51d711b
Refactor error messages for clarity and consistency in dictionary tra…
eclipse0922 Feb 25, 2026
8b31f24
DCO Remediation Commit for sewon.jeon <irocks0922@gmail.com>
eclipse0922 Feb 25, 2026
ac85826
DCO Remediation Commit for sewon jeon <irocks0922@gmail.com>
eclipse0922 Feb 25, 2026
a1a7d58
DCO Remediation Commit for sewon.jeon <irocks0922@gmail.com>
eclipse0922 Feb 25, 2026
b2efefc
DCO Remediation Commit for sewon.jeon <sewon.jeon@connecteve.com>
Feb 25, 2026
226deb6
Apply black 25.11.0 formatting fixes
eclipse0922 Feb 25, 2026
c42261b
Fix inverse matching with mixed postprocessing history
eclipse0922 Mar 6, 2026
1c3d991
Add same-class inverse history regression test
eclipse0922 Mar 6, 2026
20b8695
Handle unset start method in inverse matching
eclipse0922 Mar 7, 2026
98a22d3
Merge branch 'dev' into fix_invertd
eclipse0922 Mar 8, 2026
1a0859c
Clarify inverse matching docs and apply formatting
eclipse0922 Mar 13, 2026
67fbf30
fix format
eclipse0922 Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/detection/transforms/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def convert_box_to_mask(
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
# apply to global mask
slicing = [b]
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type: ignore
boxes_mask_np[tuple(slicing)] = boxes_only_mask
return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]

Expand Down
2 changes: 1 addition & 1 deletion monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
root: str
child_key: str
(root, _, child_key) = keys
root, _, child_key = keys
if root not in self.ops:
self.ops[root] = [{}]
self.ops[root][0].update({child_key: None})
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ def create_workflow(

"""
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
(workflow_name, config_file) = _pop_args(
workflow_name, config_file = _pop_args(
_args, workflow_name=ConfigWorkflow, config_file=None
) # the default workflow name is "ConfigWorkflow"
if isinstance(workflow_name, str):
Expand Down
4 changes: 2 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class DatasetFunc(Dataset):
"""

def __init__(self, data: Any, func: Callable, **kwargs) -> None:
super().__init__(data=None, transform=None) # type:ignore
super().__init__(data=None, transform=None) # type: ignore
self.src = data
self.func = func
self.kwargs = kwargs
Expand Down Expand Up @@ -1635,7 +1635,7 @@ def _cachecheck(self, item_transformed):
return (_data, _meta)
return _data
else:
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type: ignore
for i, _item in enumerate(item_transformed):
for k in _item:
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def stopping_fn_from_loss() -> Callable[[Engine], Any]:
"""

def stopping_fn(engine: Engine) -> Any:
return -engine.state.output # type:ignore
return -engine.state.output # type: ignore

return stopping_fn

Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def get_edge_surface_distance(
edges_spacing = None
if use_subvoxels:
edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape))
(edges_pred, edges_gt, *areas) = get_mask_edges(
edges_pred, edges_gt, *areas = get_mask_edges(
y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False
)
if not edges_gt.any():
Expand Down
50 changes: 39 additions & 11 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def get_transform_info(self) -> dict:
"""
Return a dictionary with the relevant information pertaining to an applied transform.
"""
self._init_trace_threadlocal()

vals = (
self.__class__.__name__,
id(self),
Expand Down Expand Up @@ -300,25 +302,45 @@ def track_transform_meta(
return out_obj

def check_transforms_match(self, transform: Mapping) -> None:
    """Check whether a traced transform entry matches this transform.

    When multiprocessing uses ``spawn``, transform instances are recreated
    in worker processes, so matching can fall back to the transform class
    name instead of the original instance ID (see ``_transforms_match``).

    Args:
        transform: a traced-transform dictionary, typically the most recent
            entry of an ``applied_operations`` sequence.

    Raises:
        RuntimeError: when the entry matches neither by instance ID, nor by
            the ``TraceKeys.NONE`` escape hatch, nor by class name under the
            ``spawn`` start method.
    """
    if self._transforms_match(transform):
        return

    xform_id = transform.get(TraceKeys.ID, "")
    xform_name = transform.get(TraceKeys.CLASS_NAME, "")
    # surface any warning message recorded at trace time before failing
    warning_msg = transform.get(TraceKeys.EXTRA_INFO, {}).get("warn")
    if warning_msg:
        warnings.warn(warning_msg)
    raise RuntimeError(
        f"Error {self.__class__.__name__} getting the most recently "
        f"applied invertible transform {xform_name} {xform_id} != {id(self)}."
    )

def _transforms_match(self, transform: Mapping) -> bool:
    """Return whether a traced transform entry matches this transform.

    Matching succeeds when the traced ID matches this instance, when the ID
    check is explicitly disabled with ``TraceKeys.NONE``, or when
    multiprocessing uses ``spawn`` and the traced class name matches this
    transform class.
    """
    traced_id = transform.get(TraceKeys.ID, "")
    # direct instance match, or the sentinel that disables the id check
    if traced_id in (id(self), TraceKeys.NONE):
        return True
    # under 'spawn', worker processes recreate transform objects, so fall
    # back to comparing class names instead of instance ids
    traced_name = transform.get(TraceKeys.CLASS_NAME, "")
    is_spawn = torch.multiprocessing.get_start_method(allow_none=True) == "spawn"
    return is_spawn and traced_name == self.__class__.__name__

def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
"""
Get most recent matching transform for the current class from the sequence of applied operations.
Expand Down Expand Up @@ -350,10 +372,16 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
if not all_transforms:
raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")

match_idx = len(all_transforms) - 1
if check:
self.check_transforms_match(all_transforms[-1])
for idx in range(len(all_transforms) - 1, -1, -1):
if self._transforms_match(all_transforms[idx]):
match_idx = idx
break
else:
self.check_transforms_match(all_transforms[-1])

return all_transforms.pop(-1) if pop else all_transforms[-1]
return all_transforms.pop(match_idx) if pop else all_transforms[match_idx]

def pop_transform(self, data, key: Hashable = None, check: bool = True):
"""
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""
A collection of "vanilla" transforms for IO functions.
"""

from __future__ import annotations

import inspect
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def __init__(
# if the root log level is higher than INFO, set a separate stream handler to record
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
console.is_data_stats_handler = True # type:ignore[attr-defined]
console.is_data_stats_handler = True # type: ignore[attr-defined]
_logger.addHandler(console)

def __call__(
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_loader_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.
"""this test should not generate errors or
UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores"""

from __future__ import annotations

import multiprocessing as mp
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/pyspy_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions tests/transforms/croppad/test_pad_nd_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Tests for pad_nd dtype support and backend selection.
Validates PyTorch padding preference and NumPy fallback behavior.
"""

from __future__ import annotations

import unittest
Expand Down
Loading
Loading