From 4924aa6200e215705c0570370acb83c21cd93749 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo Date: Sun, 1 Feb 2026 14:33:47 +0100 Subject: [PATCH 1/7] Fix TrainableBilateralFilter 3D input validation (#7444) - Fix dimension comparison to use spatial dims instead of total dims - Add validation for minimum input dimensions - Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma) - Move spatial dimension validation before unsqueeze operations The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected. Fixes #7444 Signed-off-by: Abdoulaye Diallo --- monai/networks/layers/filtering.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index c48c77cf98..2b46ce1b6e 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -221,7 +221,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." ) # Register sigmas as trainable parameters. @@ -231,6 +231,10 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) def forward(self, input_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " @@ -239,24 +243,25 @@ def forward(self, input_tensor): ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableBilateralFilterFunction.apply( input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction @@ -389,7 +394,7 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." ) # Register sigmas as trainable parameters. From fced85b32dbdcaa42eaedb2cdc37af8f0aa946a8 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo Date: Wed, 4 Mar 2026 13:32:58 +0100 Subject: [PATCH 2/7] fix: apply same dimension handling fixes to TrainableJointBilateralFilter Signed-off-by: Abdoulaye Diallo --- monai/networks/layers/filtering.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 2b46ce1b6e..249fcf2892 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -220,9 +220,7 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." - ) + raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2 or 3).") # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) @@ -393,9 +391,7 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." - ) + raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3).") # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) @@ -404,9 +400,13 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) def forward(self, input_tensor, guidance_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( - f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " + f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. " "Please use multiple parallel filter layers if you want " "to filter multiple channels." ) @@ -417,26 +417,27 @@ def forward(self, input_tensor, guidance_tensor): ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction From 01851680ad9f7260206ed9ce17d76c0119744581 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:07:22 +0200 Subject: [PATCH 3/7] Update monai/networks/layers/filtering.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> --- monai/networks/layers/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 249fcf2892..31dc92889b 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -250,7 +250,7 @@ def forward(self, input_tensor): input_tensor = input_tensor.unsqueeze(4) if self.len_spatial_sigma != spatial_dims: - raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") + raise ValueError(f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`.") prediction = TrainableBilateralFilterFunction.apply( input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color From 5aa42f2885bff27d7cc63a26ce0efa69d63db0a7 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:07:40 +0200 Subject: [PATCH 4/7] Update monai/networks/layers/filtering.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> --- monai/networks/layers/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 31dc92889b..d8db766325 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -220,7 +220,7 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2 or 3).") + raise ValueError(f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) or be a single float value ({spatial_sigma=}).") # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) From 6af2225b3a22470cc48309e8b1df18c6e2b71412 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:08:01 +0200 Subject: [PATCH 5/7] Update monai/networks/layers/filtering.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> --- monai/networks/layers/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index d8db766325..01c82d730a 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -428,7 +428,7 @@ def forward(self, input_tensor, guidance_tensor): guidance_tensor = guidance_tensor.unsqueeze(4) if self.len_spatial_sigma != spatial_dims: - raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") + raise ValueError(f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`.") prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color From aea9bcf2de2a80b5d753e9460a367ab5784e3132 Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> Date: Tue, 31 Mar 2026 14:08:17 +0200 Subject: [PATCH 6/7] Update monai/networks/layers/filtering.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Abdoulaye Diallo <113793273+getrichthroughcode@users.noreply.github.com> --- monai/networks/layers/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 01c82d730a..14c19d4ffd 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -391,7 +391,7 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3).") + raise ValueError(f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) or be a single float value ({spatial_sigma=}).") # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) From 1b80cfa1202cd7e226c4543a37a369e19787370a Mon Sep 17 00:00:00 2001 From: Abdoulaye Diallo Date: Wed, 1 Apr 2026 14:47:12 +0200 Subject: [PATCH 7/7] Update monai/networks/layers/filtering.py Signed-off-by: Abdoulaye Diallo --- monai/networks/layers/filtering.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index 14c19d4ffd..5d87000d64 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -220,7 +220,10 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError(f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) or be a single float value ({spatial_sigma=}).") + raise ValueError( + f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)" + f"or be a single float value ({spatial_sigma=})." + ) # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) @@ -250,7 +253,9 @@ def forward(self, input_tensor): input_tensor = input_tensor.unsqueeze(4) if self.len_spatial_sigma != spatial_dims: - raise ValueError(f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`.") + raise ValueError( + f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`." + ) prediction = TrainableBilateralFilterFunction.apply( input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color @@ -391,7 +396,10 @@ def __init__(self, spatial_sigma, color_sigma): spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]] self.len_spatial_sigma = 3 else: - raise ValueError(f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) or be a single float value ({spatial_sigma=}).") + raise ValueError( + f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)\n" + f"or be a single float value ({spatial_sigma=})." + ) # Register sigmas as trainable parameters. self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0])) @@ -412,8 +420,7 @@ def forward(self, input_tensor, guidance_tensor): ) if input_tensor.shape != guidance_tensor.shape: raise ValueError( - "Shape of input image must equal shape of guidance image." - f"Got {input_tensor.shape} and {guidance_tensor.shape}." + f"Shape of input image must equal shape of guidance image.Got {input_tensor.shape} and {guidance_tensor.shape}." ) len_input = len(input_tensor.shape) @@ -428,7 +435,9 @@ def forward(self, input_tensor, guidance_tensor): guidance_tensor = guidance_tensor.unsqueeze(4) if self.len_spatial_sigma != spatial_dims: - raise ValueError(f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`.") + raise ValueError( + f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`." + ) prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color