Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,21 +468,28 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
"""
d: dict[Hashable, MetaTensor] = dict(data)
start = time.time()
if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
using_cuda = True
else:
using_cuda = False
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
)
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore
ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D)
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
label_tensor = label_tensor.to(image_tensor.device) # type: ignore

ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
Expand Down
6 changes: 3 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ mccabe
pep8-naming
pycodestyle
pyflakes
black>=25.1.0
isort>=5.1, !=6.0.0
ruff
black==25.1.0
isort>=5.1, <6, !=6.0.0
ruff>=0.14.11,<0.15
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
types-setuptools
mypy>=1.5.0, <1.12.0
Expand Down
53 changes: 52 additions & 1 deletion tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
SqueezeDimd,
ToDeviced,
)
from monai.utils.enums import DataStatsKeys
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
from tests.test_utils import skip_if_no_cuda

device = "cpu"
Expand All @@ -78,6 +78,13 @@

SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]]

LABEL_STATS_DEVICE_TEST_CASES = [
[{"image_device": "cpu", "label_device": "cpu", "image_meta": False}],
[{"image_device": "cuda", "label_device": "cuda", "image_meta": True}],
[{"image_device": "cpu", "label_device": "cuda", "image_meta": True}],
[{"image_device": "cuda", "label_device": "cpu", "image_meta": False}],
]


def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None:
"""
Expand Down Expand Up @@ -360,6 +367,50 @@ def test_label_stats_case_analyzer(self):
report_format = analyzer.get_report_format()
assert verify_report_format(d["label_stats"], report_format)

@parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES)
def test_label_stats_mixed_device_analyzer(self, input_params):
image_device = torch.device(input_params["image_device"])
label_device = torch.device(input_params["label_device"])

if (image_device.type == "cuda" or label_device.type == "cuda") and not torch.cuda.is_available():
self.skipTest("CUDA is not available for mixed-device LabelStats tests.")

analyzer = LabelStats(image_key="image", label_key="label")

image_tensor = torch.tensor(
[
[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
[[[11.0, 12.0], [13.0, 14.0]], [[15.0, 16.0], [17.0, 18.0]]],
],
dtype=torch.float32,
).to(image_device)
label_tensor = torch.tensor([[[0, 1], [1, 0]], [[0, 1], [0, 1]]], dtype=torch.int64).to(label_device)

if input_params["image_meta"]:
image_tensor = MetaTensor(image_tensor)
label_tensor = MetaTensor(label_tensor)

result = analyzer({"image": image_tensor, "label": label_tensor})
report = result["label_stats"]

assert verify_report_format(report, analyzer.get_report_format())
assert report[LabelStatsKeys.LABEL_UID] == [0, 1]

label_stats = report[LabelStatsKeys.LABEL]
self.assertAlmostEqual(label_stats[0][LabelStatsKeys.PIXEL_PCT], 0.5)
self.assertAlmostEqual(label_stats[1][LabelStatsKeys.PIXEL_PCT], 0.5)

label0_intensity = label_stats[0][LabelStatsKeys.IMAGE_INTST]
label1_intensity = label_stats[1][LabelStatsKeys.IMAGE_INTST]
self.assertAlmostEqual(label0_intensity[0]["mean"], 4.25)
self.assertAlmostEqual(label1_intensity[0]["mean"], 4.75)
self.assertAlmostEqual(label0_intensity[1]["mean"], 14.25)
self.assertAlmostEqual(label1_intensity[1]["mean"], 14.75)

foreground_stats = report[LabelStatsKeys.IMAGE_INTST]
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)

def test_filename_case_analyzer(self):
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)
Expand Down
Loading