diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py index 4408d602bd..d48d5fc878 100644 --- a/monai/auto3dseg/analyzer.py +++ b/monai/auto3dseg/analyzer.py @@ -468,21 +468,28 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe """ d: dict[Hashable, MetaTensor] = dict(data) start = time.time() - if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda": - using_cuda = True - else: - using_cuda = False + image_tensor = d[self.image_key] + label_tensor = d[self.label_key] + using_cuda = any( + isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor) + ) restore_grad_state = torch.is_grad_enabled() torch.set_grad_enabled(False) - ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])] # type: ignore - ndas_label: MetaTensor = d[self.label_key].astype(torch.int16) # (H,W,D) + if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance( + label_tensor, (MetaTensor, torch.Tensor) + ): + if label_tensor.device != image_tensor.device: + label_tensor = label_tensor.to(image_tensor.device) # type: ignore + + ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore + ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D) if ndas_label.shape != ndas[0].shape: raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas] - nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds] + nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] unique_label = unique(ndas_label) if isinstance(ndas_label, (MetaTensor, torch.Tensor)): diff --git a/requirements-dev.txt b/requirements-dev.txt index 3e9189a1c3..16d91a39a7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,9 +14,9 @@ mccabe pep8-naming pycodestyle pyflakes -black>=25.1.0 -isort>=5.1, !=6.0.0 -ruff +black==25.1.0 +isort>=5.1, <6, !=6.0.0 +ruff>=0.14.11,<0.15 pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows" types-setuptools mypy>=1.5.0, <1.12.0 diff --git a/tests/apps/test_auto3dseg.py b/tests/apps/test_auto3dseg.py index 6c0d8123d7..2159265873 100644 --- a/tests/apps/test_auto3dseg.py +++ b/tests/apps/test_auto3dseg.py @@ -53,7 +53,7 @@ SqueezeDimd, ToDeviced, ) -from monai.utils.enums import DataStatsKeys +from monai.utils.enums import DataStatsKeys, LabelStatsKeys from tests.test_utils import skip_if_no_cuda device = "cpu" @@ -78,6 +78,13 @@ SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]] +LABEL_STATS_DEVICE_TEST_CASES = [ + [{"image_device": "cpu", "label_device": "cpu", "image_meta": False}], + [{"image_device": "cuda", "label_device": "cuda", "image_meta": True}], + [{"image_device": "cpu", "label_device": "cuda", "image_meta": True}], + [{"image_device": "cuda", "label_device": "cpu", "image_meta": False}], +] + def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None: """ @@ -360,6 +367,50 @@ def test_label_stats_case_analyzer(self): report_format = analyzer.get_report_format() assert verify_report_format(d["label_stats"], report_format) + @parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES) + def test_label_stats_mixed_device_analyzer(self, input_params): + image_device = torch.device(input_params["image_device"]) + label_device = torch.device(input_params["label_device"]) + + if (image_device.type == "cuda" or label_device.type == "cuda") and not torch.cuda.is_available(): + self.skipTest("CUDA is not available for mixed-device LabelStats tests.") + + analyzer = LabelStats(image_key="image", label_key="label") + + image_tensor = torch.tensor( + [ + [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], + [[[11.0, 12.0], [13.0, 14.0]], [[15.0, 16.0], [17.0, 18.0]]], + ], + dtype=torch.float32, + ).to(image_device) + label_tensor = torch.tensor([[[0, 1], [1, 0]], [[0, 1], [0, 1]]], dtype=torch.int64).to(label_device) + + if input_params["image_meta"]: + image_tensor = MetaTensor(image_tensor) + label_tensor = MetaTensor(label_tensor) + + result = analyzer({"image": image_tensor, "label": label_tensor}) + report = result["label_stats"] + + assert verify_report_format(report, analyzer.get_report_format()) + assert report[LabelStatsKeys.LABEL_UID] == [0, 1] + + label_stats = report[LabelStatsKeys.LABEL] + self.assertAlmostEqual(label_stats[0][LabelStatsKeys.PIXEL_PCT], 0.5) + self.assertAlmostEqual(label_stats[1][LabelStatsKeys.PIXEL_PCT], 0.5) + + label0_intensity = label_stats[0][LabelStatsKeys.IMAGE_INTST] + label1_intensity = label_stats[1][LabelStatsKeys.IMAGE_INTST] + self.assertAlmostEqual(label0_intensity[0]["mean"], 4.25) + self.assertAlmostEqual(label1_intensity[0]["mean"], 4.75) + self.assertAlmostEqual(label0_intensity[1]["mean"], 14.25) + self.assertAlmostEqual(label1_intensity[1]["mean"], 14.75) + + foreground_stats = report[LabelStatsKeys.IMAGE_INTST] + self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75) + self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75) + def test_filename_case_analyzer(self): analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH) analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)