Skip to content
49 changes: 32 additions & 17 deletions monai/networks/layers/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def __init__(self, spatial_sigma, color_sigma):
self.len_spatial_sigma = 3
else:
raise ValueError(
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)"
f"or be a single float value ({spatial_sigma=})."
)

# Register sigmas as trainable parameters.
Expand All @@ -231,6 +232,10 @@ def __init__(self, spatial_sigma, color_sigma):
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))

def forward(self, input_tensor):
if len(input_tensor.shape) < 3:
raise ValueError(
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
)
if input_tensor.shape[1] != 1:
raise ValueError(
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
Expand All @@ -239,24 +244,27 @@ def forward(self, input_tensor):
)

len_input = len(input_tensor.shape)
spatial_dims = len_input - 2

# C++ extension so far only supports 5-dim inputs.
if len_input == 3:
if spatial_dims == 1:
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
elif len_input == 4:
elif spatial_dims == 2:
input_tensor = input_tensor.unsqueeze(4)

if self.len_spatial_sigma != len_input:
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
if self.len_spatial_sigma != spatial_dims:
raise ValueError(
f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
)

prediction = TrainableBilateralFilterFunction.apply(
input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
)

# Make sure to return tensor of the same shape as the input.
if len_input == 3:
if spatial_dims == 1:
prediction = prediction.squeeze(4).squeeze(3)
elif len_input == 4:
elif spatial_dims == 2:
prediction = prediction.squeeze(4)

return prediction
Expand Down Expand Up @@ -389,7 +397,8 @@ def __init__(self, spatial_sigma, color_sigma):
self.len_spatial_sigma = 3
else:
raise ValueError(
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
f"Length of `spatial_sigma` must match number of spatial dims (1, 2 or 3)\n"
f"or be a single float value ({spatial_sigma=})."
)

# Register sigmas as trainable parameters.
Expand All @@ -399,39 +408,45 @@ def __init__(self, spatial_sigma, color_sigma):
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))

def forward(self, input_tensor, guidance_tensor):
    """Apply the trainable joint bilateral filter to ``input_tensor``.

    Args:
        input_tensor: tensor of shape (batch, channel, *spatial_dims) with
            exactly one channel and 1-3 spatial dimensions.
        guidance_tensor: guidance image; must have the same shape as
            ``input_tensor``.

    Returns:
        Filtered tensor with the same shape as ``input_tensor``.

    Raises:
        ValueError: if the input has fewer than 3 dimensions, more than one
            channel, a shape differing from the guidance image, or a number
            of spatial dimensions that does not match the initialized
            ``spatial_sigma``.
    """
    if len(input_tensor.shape) < 3:
        raise ValueError(
            f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
        )
    if input_tensor.shape[1] != 1:
        raise ValueError(
            f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. "
            "Please use multiple parallel filter layers if you want "
            "to filter multiple channels."
        )
    if input_tensor.shape != guidance_tensor.shape:
        # NOTE: a space is needed after the period, otherwise the message
        # renders as "...guidance image.Got torch.Size(...)".
        raise ValueError(
            f"Shape of input image must equal shape of guidance image. Got {input_tensor.shape} and {guidance_tensor.shape}."
        )

    len_input = len(input_tensor.shape)
    spatial_dims = len_input - 2

    # C++ extension so far only supports 5-dim inputs: pad missing spatial
    # dimensions with singleton axes (both input and guidance must match).
    if spatial_dims == 1:
        input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
        guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
    elif spatial_dims == 2:
        input_tensor = input_tensor.unsqueeze(4)
        guidance_tensor = guidance_tensor.unsqueeze(4)

    if self.len_spatial_sigma != spatial_dims:
        raise ValueError(
            f"Number of spatial dimensions ({spatial_dims}) must match initialized `len(spatial_sigma)`."
        )

    prediction = TrainableJointBilateralFilterFunction.apply(
        input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
    )

    # Make sure to return a tensor of the same shape as the input by
    # removing the singleton axes added above.
    if spatial_dims == 1:
        prediction = prediction.squeeze(4).squeeze(3)
    elif spatial_dims == 2:
        prediction = prediction.squeeze(4)

    return prediction
Loading