diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py index c48c77cf98..5d87000d64 100644 --- a/monai/networks/layers/filtering.py +++ b/monai/networks/layers/filtering.py @@ -221,7 +221,8 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) " + f"or be a single float value ({spatial_sigma=})." ) # Register sigmas as trainable parameters. @@ -231,6 +232,10 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) def forward(self, input_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " "Please use multiple parallel filter layers if you want " "to filter multiple channels." ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError( + f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`." + ) prediction = TrainableBilateralFilterFunction.apply( input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. 
- if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction @@ -389,7 +397,8 @@ def __init__(self, spatial_sigma, color_sigma): self.len_spatial_sigma = 3 else: raise ValueError( - f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." + f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3) " + f"or be a single float value ({spatial_sigma=})." ) # Register sigmas as trainable parameters. @@ -399,39 +408,45 @@ def __init__(self, spatial_sigma, color_sigma): self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma)) def forward(self, input_tensor, guidance_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( - f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " + f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. " "Please use multiple parallel filter layers if you want " "to filter multiple channels." ) if input_tensor.shape != guidance_tensor.shape: raise ValueError( - "Shape of input image must equal shape of guidance image." - f"Got {input_tensor.shape} and {guidance_tensor.shape}." + f"Shape of input image must equal shape of guidance image. Got {input_tensor.shape} and {guidance_tensor.shape}." ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. 
- if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError( + f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`." + ) prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction