Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@
RepeatChannelD,
RepeatChannelDict,
SelectItemsd,
SelectItemsD,
SelectItemsDict,
SimulateDelayd,
SimulateDelayD,
SimulateDelayDict,
Expand Down Expand Up @@ -395,6 +397,7 @@
img_bounds,
in_bounds,
is_empty,
is_positive,
map_binary_to_indices,
map_spatial_axes,
rand_choice,
Expand Down
14 changes: 11 additions & 3 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
compute_divisible_spatial_size,
generate_pos_neg_label_crop_centers,
generate_spatial_bounding_box,
is_positive,
map_binary_to_indices,
weighted_patch_samples,
)
Expand Down Expand Up @@ -400,7 +401,14 @@ class CropForeground(Transform):
[0, 1, 3, 2, 0],
[0, 1, 2, 1, 0],
[0, 0, 0, 0, 0]]]) # 1x5x5, single channel 5x5 image
cropper = CropForeground(select_fn=lambda x: x > 1, margin=0)


def threshold_at_one(x):
# threshold at 1
return x > 1


cropper = CropForeground(select_fn=threshold_at_one, margin=0)
print(cropper(image))
[[[2, 1],
[3, 2],
Expand All @@ -410,7 +418,7 @@ class CropForeground(Transform):

def __init__(
self,
select_fn: Callable = lambda x: x > 0,
select_fn: Callable = is_positive,
channel_indices: Optional[IndexSelection] = None,
margin: Union[Sequence[int], int] = 0,
return_coords: bool = False,
Expand Down Expand Up @@ -725,7 +733,7 @@ class BoundingRect(Transform):
select_fn: function to select expected foreground, default is to select values > 0.
"""

def __init__(self, select_fn: Callable = lambda x: x > 0) -> None:
def __init__(self, select_fn: Callable = is_positive) -> None:
self.select_fn = select_fn

def __call__(self, img: np.ndarray) -> np.ndarray:
Expand Down
11 changes: 8 additions & 3 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
)
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices, weighted_patch_samples
from monai.transforms.utils import (
generate_pos_neg_label_crop_centers,
is_positive,
map_binary_to_indices,
weighted_patch_samples,
)
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import InverseKeys
Expand Down Expand Up @@ -572,7 +577,7 @@ def __init__(
self,
keys: KeysCollection,
source_key: str,
select_fn: Callable = lambda x: x > 0,
select_fn: Callable = is_positive,
channel_indices: Optional[IndexSelection] = None,
margin: int = 0,
k_divisible: Union[Sequence[int], int] = 1,
Expand Down Expand Up @@ -948,7 +953,7 @@ def __init__(
self,
keys: KeysCollection,
bbox_key_postfix: str = "bbox",
select_fn: Callable = lambda x: x > 0,
select_fn: Callable = is_positive,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@
"RandGaussianSharpenDict",
"RandHistogramShiftD",
"RandHistogramShiftDict",
"RandRicianNoiseD",
"RandRicianNoiseDict",
]


Expand Down
3 changes: 2 additions & 1 deletion monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"LabelToContour",
"MeanEnsemble",
"VoteEnsemble",
"ProbNMS",
]


Expand Down Expand Up @@ -74,7 +75,7 @@ def __call__(
softmax: whether to execute softmax function on model output before transform.
Defaults to ``self.softmax``.
other: callable function to execute other activation layers, for example:
`other = lambda x: torch.tanh(x)`. Defaults to ``self.other``.
`other = torch.tanh`. Defaults to ``self.other``.

Raises:
ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.
Expand Down
5 changes: 4 additions & 1 deletion monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
"DecollateD",
"DecollateDict",
"Decollated",
"ProbNMSd",
"ProbNMSD",
"ProbNMSDict",
]


Expand All @@ -83,7 +86,7 @@ def __init__(
softmax: whether to execute softmax function on model output before transform.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
other: callable function to execute other activation layers,
for example: `other = lambda x: torch.tanh(x)`. it also can be a sequence of Callable, each
for example: `other = torch.tanh`. it also can be a sequence of Callable, each
element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.

Expand Down
1 change: 1 addition & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"CastToType",
"ToTensor",
"ToNumpy",
"ToPIL",
"Transpose",
"SqueezeDim",
"DataStats",
Expand Down
127 changes: 67 additions & 60 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,84 +53,90 @@
from monai.utils import ensure_tuple, ensure_tuple_rep

__all__ = [
"Identityd",
"AsChannelFirstd",
"AsChannelLastd",
"AddChannelD",
"AddChannelDict",
"AddChanneld",
"EnsureChannelFirstd",
"RepeatChanneld",
"RemoveRepeatedChanneld",
"SplitChanneld",
"CastToTyped",
"ToTensord",
"ToNumpyd",
"ToPILd",
"DeleteItemsd",
"SelectItemsd",
"SqueezeDimd",
"DataStatsd",
"SimulateDelayd",
"CopyItemsd",
"ConcatItemsd",
"Lambdad",
"RandLambdad",
"LabelToMaskd",
"FgBgToIndicesd",
"ConvertToMultiChannelBasedOnBratsClassesd",
"AddExtremePointsChannelD",
"AddExtremePointsChannelDict",
"AddExtremePointsChanneld",
"TorchVisiond",
"RandTorchVisiond",
"MapLabelValued",
"IdentityD",
"IdentityDict",
"AsChannelFirstD",
"AsChannelFirstDict",
"AsChannelFirstd",
"AsChannelLastD",
"AsChannelLastDict",
"AddChannelD",
"AddChannelDict",
"AsChannelLastd",
"CastToTypeD",
"CastToTypeDict",
"CastToTyped",
"ConcatItemsD",
"ConcatItemsDict",
"ConcatItemsd",
"ConvertToMultiChannelBasedOnBratsClassesD",
"ConvertToMultiChannelBasedOnBratsClassesDict",
"ConvertToMultiChannelBasedOnBratsClassesd",
"CopyItemsD",
"CopyItemsDict",
"CopyItemsd",
"DataStatsD",
"DataStatsDict",
"DataStatsd",
"DeleteItemsD",
"DeleteItemsDict",
"DeleteItemsd",
"EnsureChannelFirstD",
"EnsureChannelFirstDict",
"EnsureChannelFirstd",
"FgBgToIndicesD",
"FgBgToIndicesDict",
"FgBgToIndicesd",
"IdentityD",
"IdentityDict",
"Identityd",
"LabelToMaskD",
"LabelToMaskDict",
"LabelToMaskd",
"LambdaD",
"LambdaDict",
"Lambdad",
"MapLabelValueD",
"MapLabelValueDict",
"MapLabelValued",
"RandLambdaD",
"RandLambdaDict",
"RepeatChannelD",
"RepeatChannelDict",
"RandLambdad",
"RandTorchVisionD",
"RandTorchVisionDict",
"RandTorchVisiond",
"RemoveRepeatedChannelD",
"RemoveRepeatedChannelDict",
"RemoveRepeatedChanneld",
"RepeatChannelD",
"RepeatChannelDict",
"RepeatChanneld",
"SelectItemsD",
"SelectItemsDict",
"SelectItemsd",
"SimulateDelayD",
"SimulateDelayDict",
"SimulateDelayd",
"SplitChannelD",
"SplitChannelDict",
"CastToTypeD",
"CastToTypeDict",
"ToTensorD",
"ToTensorDict",
"DeleteItemsD",
"DeleteItemsDict",
"SplitChanneld",
"SqueezeDimD",
"SqueezeDimDict",
"DataStatsD",
"DataStatsDict",
"SimulateDelayD",
"SimulateDelayDict",
"CopyItemsD",
"CopyItemsDict",
"ConcatItemsD",
"ConcatItemsDict",
"LambdaD",
"LambdaDict",
"LabelToMaskD",
"LabelToMaskDict",
"FgBgToIndicesD",
"FgBgToIndicesDict",
"ConvertToMultiChannelBasedOnBratsClassesD",
"ConvertToMultiChannelBasedOnBratsClassesDict",
"AddExtremePointsChannelD",
"AddExtremePointsChannelDict",
"SqueezeDimd",
"ToNumpyD",
"ToNumpyDict",
"ToNumpyd",
"ToPILD",
"ToPILDict",
"ToPILd",
"ToTensorD",
"ToTensorDict",
"ToTensord",
"TorchVisionD",
"TorchVisionDict",
"RandTorchVisionD",
"RandTorchVisionDict",
"MapLabelValueD",
"MapLabelValueDict",
"TorchVisiond",
]


Expand Down Expand Up @@ -1062,6 +1068,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
ToNumpyD = ToNumpyDict = ToNumpyd
ToPILD = ToPILDict = ToPILd
DeleteItemsD = DeleteItemsDict = DeleteItemsd
SelectItemsD = SelectItemsDict = SelectItemsd
SqueezeDimD = SqueezeDimDict = SqueezeDimd
DataStatsD = DataStatsDict = DataStatsd
SimulateDelayD = SimulateDelayDict = SimulateDelayd
Expand Down
11 changes: 10 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@
"img_bounds",
"in_bounds",
"is_empty",
"is_positive",
"zero_margins",
"rescale_array",
"rescale_instance_array",
"rescale_array_int_max",
"copypaste_arrays",
"compute_divisible_spatial_size",
"resize_center",
"map_binary_to_indices",
"weighted_patch_samples",
Expand Down Expand Up @@ -97,6 +99,13 @@ def is_empty(img: Union[np.ndarray, torch.Tensor]) -> bool:
return not (img.max() > img.min()) # use > instead of <= so that an image full of NaNs will result in True


def is_positive(img):
"""
Returns a boolean version of `img` where the positive values are converted into True, the other values are False.
"""
return img > 0


def zero_margins(img: np.ndarray, margin: int) -> bool:
"""
Returns True if the values within `margin` indices of the edges of `img` in dimensions 1 and 2 are 0.
Expand Down Expand Up @@ -526,7 +535,7 @@ def create_translate(spatial_dims: int, shift: Union[Sequence[float], float]) ->

def generate_spatial_bounding_box(
img: np.ndarray,
select_fn: Callable = lambda x: x > 0,
select_fn: Callable = is_positive,
channel_indices: Optional[IndexSelection] = None,
margin: Union[Sequence[int], int] = 0,
) -> Tuple[List[int], List[int]]:
Expand Down
1 change: 1 addition & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"Method",
"InverseKeys",
"CommonKeys",
"ForwardMode",
]


Expand Down