From f6673faef3559b950f5bb8b2e2ee26f4ec91c630 Mon Sep 17 00:00:00 2001 From: Aaron Ponti Date: Thu, 5 Feb 2026 15:14:28 +0100 Subject: [PATCH 1/6] Fix using a non-tuple sequence for multidimensional indexing that will result in error in pytorch 2.9. --- monai/inferers/utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 766486a807..5553cdcacf 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -223,14 +223,22 @@ def sliding_window_inference( for idx in slice_range ] if sw_batch_size > 1: - win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + win_data = torch.cat( + [inputs[tuple(win_slice) if isinstance(win_slice, list) else win_slice] for win_slice in unravel_slice] + ).to(sw_device) if condition is not None: - win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device) + win_condition = torch.cat( + [condition[tuple(win_slice) if isinstance(win_slice, list) else win_slice] for win_slice in + unravel_slice] + ).to(sw_device) kwargs["condition"] = win_condition else: - win_data = inputs[unravel_slice[0]].to(sw_device) + s0 = unravel_slice[0] + s0_idx = tuple(s0) if isinstance(s0, list) else s0 + + win_data = inputs[s0_idx].to(sw_device) if condition is not None: - win_condition = condition[unravel_slice[0]].to(sw_device) + win_condition = condition[s0_idx].to(sw_device) kwargs["condition"] = win_condition if with_coord: @@ -257,7 +265,7 @@ def sliding_window_inference( offset = s[buffer_dim + 2].start - c_start s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) s[0] = slice(0, 1) - sw_device_buffer[0][s] += p * w_t + sw_device_buffer[0][tuple(s) if isinstance(s, list) else s] += p * w_t b_i += len(unravel_slice) if b_i < b_slices[b_s][0]: continue @@ -288,10 +296,11 @@ def sliding_window_inference( o_slice[buffer_dim + 2] = slice(c_start, c_end) img_b = b_s // n_per_batch # image batch index o_slice[0] = slice(img_b, img_b + 1) + o_slice_idx = tuple(o_slice) if isinstance(o_slice, list) else o_slice if non_blocking: - output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking) + output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking) else: - output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) + output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device) else: sw_device_buffer[ss] *= w_t sw_device_buffer[ss] = sw_device_buffer[ss].to(device) @@ -367,7 +376,7 @@ def _compute_coords(coords, z_scale, out, patch): idx_zm[axis] = slice( int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2]) ) - out[idx_zm] += p + out[tuple(idx_zm)] += p def _get_scan_interval( From 922dfff4db19ec733c4c48e8bd62ff6f387fed65 Mon Sep 17 00:00:00 2001 From: Aaron Ponti Date: Thu, 5 Feb 2026 18:35:48 +0100 Subject: [PATCH 2/6] Fix formatting. --- monai/inferers/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 5553cdcacf..2ef754ddac 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -228,8 +228,10 @@ def sliding_window_inference( ).to(sw_device) if condition is not None: win_condition = torch.cat( - [condition[tuple(win_slice) if isinstance(win_slice, list) else win_slice] for win_slice in - unravel_slice] + [ + condition[tuple(win_slice) if isinstance(win_slice, list) else win_slice] + for win_slice in unravel_slice + ] ).to(sw_device) kwargs["condition"] = win_condition else: From 57ffb77a1419ae68182d6bf59f32b92686226134 Mon Sep 17 00:00:00 2001 From: Aaron Ponti Date: Sat, 7 Feb 2026 13:40:08 +0100 Subject: [PATCH 3/6] DCO Remediation Commit for Aaron Ponti I, Aaron Ponti , hereby add my Signed-off-by to this commit: f6673faef3559b950f5bb8b2e2ee26f4ec91c630 I, Aaron Ponti , hereby add my Signed-off-by to this commit: 922dfff4db19ec733c4c48e8bd62ff6f387fed65 Signed-off-by: Aaron Ponti From 939a5bbfc5dea9fa8db28d18b269e19b33baa405 Mon Sep 17 00:00:00 2001 From: Aaron Ponti Date: Mon, 2 Mar 2026 10:46:13 +0100 Subject: [PATCH 4/6] Use monai.utils.ensure_tuple instead of explicit casting. Signed-off-by: Aaron Ponti --- monai/inferers/utils.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 2ef754ddac..f3c9746477 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -223,20 +223,15 @@ def sliding_window_inference( for idx in slice_range ] if sw_batch_size > 1: - win_data = torch.cat( - [inputs[tuple(win_slice) if isinstance(win_slice, list) else win_slice] for win_slice in unravel_slice] - ).to(sw_device) + win_data = torch.cat([inputs[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(sw_device) if condition is not None: - win_condition = torch.cat( - [ - condition[tuple(win_slice) if isinstance(win_slice, list) else win_slice] - for win_slice in unravel_slice - ] - ).to(sw_device) + win_condition = torch.cat([condition[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to( + sw_device + ) kwargs["condition"] = win_condition else: s0 = unravel_slice[0] - s0_idx = tuple(s0) if isinstance(s0, list) else s0 + s0_idx = ensure_tuple(s0) win_data = inputs[s0_idx].to(sw_device) if condition is not None: @@ -267,7 +262,7 @@ def sliding_window_inference( offset = s[buffer_dim + 2].start - c_start s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) s[0] = slice(0, 1) - sw_device_buffer[0][tuple(s) if isinstance(s, list) else s] += p * w_t + sw_device_buffer[0][ensure_tuple(s)] += p * w_t b_i += len(unravel_slice) if b_i < b_slices[b_s][0]: continue @@ -298,7 +293,7 @@ def sliding_window_inference( o_slice[buffer_dim + 2] = slice(c_start, c_end) img_b = b_s // n_per_batch # image batch index o_slice[0] = slice(img_b, img_b + 1) - o_slice_idx = tuple(o_slice) if isinstance(o_slice, list) else o_slice + o_slice_idx = ensure_tuple(o_slice) if non_blocking: output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking) else: @@ -378,7 +373,7 @@ def _compute_coords(coords, z_scale, out, patch): idx_zm[axis] = slice( int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2]) ) - out[tuple(idx_zm)] += p + out[ensure_tuple(idx_zm)] += p def _get_scan_interval( From 95dac4b5619917df7e5604f199b239e06ce43329 Mon Sep 17 00:00:00 2001 From: Aaron Ponti Date: Mon, 2 Mar 2026 12:21:17 +0100 Subject: [PATCH 5/6] Test different sw_batch_sizes and buffered vs. non-buffered flows in monai.inferers.sliding_window_inference. Signed-off-by: Aaron Ponti --- .../inferers/test_sliding_window_inference.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/inferers/test_sliding_window_inference.py b/tests/inferers/test_sliding_window_inference.py index 8700c4fcd0..3171c2a030 100644 --- a/tests/inferers/test_sliding_window_inference.py +++ b/tests/inferers/test_sliding_window_inference.py @@ -20,6 +20,7 @@ from monai.data.utils import list_data_collate from monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference +from monai.inferers.utils import _compute_coords from monai.utils import optional_import from tests.test_utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick @@ -704,6 +705,53 @@ def compute_dict(data, condition): for rr, _ in zip(result_dict, expected_dict): np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4) + @parameterized.expand([(1,), (4,)]) + def test_conditioned_branches_and_buffered_parity(self, sw_batch_size): + inputs = torch.arange(1 * 1 * 10 * 8, dtype=torch.float).reshape(1, 1, 10, 8) + condition = inputs + 100.0 + roi_shape = (4, 4) + + def compute(data, condition): + self.assertEqual(data.device.type, "cpu") + self.assertEqual(condition.device.type, "cpu") + torch.testing.assert_close(condition - data, torch.full_like(data, 100.0)) + return data + condition + + # Non-buffered flow. + result_non_buffered = sliding_window_inference( + inputs, roi_shape, sw_batch_size, compute, overlap=0.5, mode="constant", condition=condition + ) + # Buffered flow; should match the non-buffered output. + result_buffered = sliding_window_inference( + inputs, + roi_shape, + sw_batch_size, + compute, + overlap=0.5, + mode="constant", + condition=condition, + buffer_steps=2, + buffer_dim=0, + ) + + expected = inputs + condition + torch.testing.assert_close(result_non_buffered, expected) + torch.testing.assert_close(result_buffered, expected) + torch.testing.assert_close(result_buffered, result_non_buffered) + + +class TestSlidingWindowUtils(unittest.TestCase): + def test_compute_coords_accepts_list_indices(self): + out = torch.zeros((1, 1, 12, 12), dtype=torch.float) + patch = torch.arange(16, dtype=torch.float).reshape(1, 1, 4, 4) + coords = [[slice(0, 1), slice(None), slice(1, 3), slice(2, 4)]] + + _compute_coords(coords=coords, z_scale=[2.0, 2.0], out=out, patch=patch) + + expected = torch.zeros_like(out) + expected[0, 0, 2:6, 4:8] = patch[0, 0] + torch.testing.assert_close(out, expected) + if __name__ == "__main__": unittest.main() From 841bb45b657b836dd1a96f0ecc1d2066fb44772a Mon Sep 17 00:00:00 2001 From: Aaron Ponti Date: Mon, 2 Mar 2026 18:33:16 +0100 Subject: [PATCH 6/6] Add missing docstrings for test_conditioned_branches_and_buffered_parity, (nested) compute, TestSlidingWindowUtils, test_compute_coords_accepts_list_indices. Signed-off-by: Aaron Ponti --- .../inferers/test_sliding_window_inference.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/inferers/test_sliding_window_inference.py b/tests/inferers/test_sliding_window_inference.py index 3171c2a030..5a624c787f 100644 --- a/tests/inferers/test_sliding_window_inference.py +++ b/tests/inferers/test_sliding_window_inference.py @@ -707,11 +707,34 @@ def compute_dict(data, condition): @parameterized.expand([(1,), (4,)]) def test_conditioned_branches_and_buffered_parity(self, sw_batch_size): + """Validate conditioned parity between buffered and non-buffered flows. + + Args: + sw_batch_size (int): Sliding-window batch size. + + Returns: + None. + + Raises: + AssertionError: If device, conditioning alignment, or output parity checks fail. + """ inputs = torch.arange(1 * 1 * 10 * 8, dtype=torch.float).reshape(1, 1, 10, 8) condition = inputs + 100.0 roi_shape = (4, 4) def compute(data, condition): + """Compute output for a conditioned patch. + + Args: + data (torch.Tensor): Input patch tensor. + condition (torch.Tensor): Conditioning patch tensor aligned to ``data``. + + Returns: + torch.Tensor: Element-wise ``data + condition``. + + Raises: + AssertionError: If device placement or conditioning alignment checks fail. + """ self.assertEqual(data.device.type, "cpu") self.assertEqual(condition.device.type, "cpu") torch.testing.assert_close(condition - data, torch.full_like(data, 100.0)) @@ -741,7 +764,30 @@ def compute(data, condition): class TestSlidingWindowUtils(unittest.TestCase): + """Tests for low-level sliding-window utility helpers. + + Args: + None. + + Returns: + None. + + Raises: + None. + """ + def test_compute_coords_accepts_list_indices(self): + """Ensure ``_compute_coords`` handles list-based index containers. + + Args: + None. + + Returns: + None. + + Raises: + AssertionError: If computed output placement differs from expected placement. + """ out = torch.zeros((1, 1, 12, 12), dtype=torch.float) patch = torch.arange(16, dtype=torch.float).reshape(1, 1, 4, 4) coords = [[slice(0, 1), slice(None), slice(1, 3), slice(2, 4)]]