From 4924aa6200e215705c0570370acb83c21cd93749 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo Date: Sun, 1 Feb 2026 14:33:47 +0100 Subject: [PATCH 1/3] Fix TrainableBilateralFilter 3D input validation (#7444) - Fix dimension comparison to use spatial dims instead of total dims - Add validation for minimum input dimensions - Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma) - Move spatial dimension validation before unsqueeze operations The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected. Fixes #7444 Signed-off-by: Abdoulaye Diallo --- monai/networks/layers/filtering.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index c48c77cf98..2b46ce1b6e 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -221,7 +221,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." ) # Register sigmas as trainable parameters. @@ -231,6 +231,10 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) def forward(self, input_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. 
" @@ -239,24 +243,25 @@ def forward(self, input_tensor): ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableBilateralFilterFunction.apply( input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction @@ -389,7 +394,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." ) # Register sigmas as trainable parameters. 
From 8264cac73b83ce805898f450e6ad1fbcb8b504dd Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo Date: Wed, 4 Mar 2026 13:32:58 +0100 Subject: [PATCH 2/3] fix: apply same dimension handling fixes to TrainableJointBilateralFilter --- monai/networks/layers/filtering.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 2b46ce1b6e..53b1d7a83a 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -221,7 +221,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3)." ) # Register sigmas as trainable parameters. @@ -394,7 +394,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3)." ) # Register sigmas as trainable parameters. @@ -402,11 +402,15 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_y = torch.nn.Parameter(torch.tensor(spatial_sigma[1])) self.sigma_z = torch.nn.Parameter(torch.tensor(spatial_sigma[2])) self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) - + def forward(self, input_tensor, guidance_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( - f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " + f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. 
" "Please use multiple parallel filter layers if you want " "to filter multiple channels." ) @@ -417,26 +421,28 @@ def forward(self, input_tensor, guidance_tensor): ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. 
- if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction + From a8c7a0ed01964618066e6f64ef544d1d5d7c5055 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:35:07 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/layers/filtering.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 53b1d7a83a..56676b4f93 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -402,7 +402,7 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_y = torch.nn.Parameter(torch.tensor(spatial_sigma[1])) self.sigma_z = torch.nn.Parameter(torch.tensor(spatial_sigma[2])) self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) - + def forward(self, input_tensor, guidance_tensor): if len(input_tensor.shape) < 3: raise ValueError( @@ -445,4 +445,3 @@ def forward(self, input_tensor, guidance_tensor): prediction = prediction.squeeze(4) return prediction -