From 91e7259fdb21ccd79aeab4655c0ce29809099592 Mon Sep 17 00:00:00 2001 From: Victor Fung Date: Fri, 26 Jan 2024 09:38:28 -0500 Subject: [PATCH 1/4] Update test_predict.yml --- test/configs/cpu/test_predict.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/configs/cpu/test_predict.yml b/test/configs/cpu/test_predict.yml index f6efe6f7..6387fef9 100644 --- a/test/configs/cpu/test_predict.yml +++ b/test/configs/cpu/test_predict.yml @@ -46,6 +46,7 @@ model: otf_node_attr: False # compute gradients w.r.t to positions and cell, requires otf_edge=True gradient: False + model_ensemble: 1 optim: max_epochs: 5 From b71996c46c9f193cd9e4338c8308f818fd92b8e4 Mon Sep 17 00:00:00 2001 From: Victor Fung Date: Fri, 26 Jan 2024 09:39:30 -0500 Subject: [PATCH 2/4] Update test_training.yml --- test/configs/cpu/test_training.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/configs/cpu/test_training.yml b/test/configs/cpu/test_training.yml index a4d38b74..6515e6b6 100644 --- a/test/configs/cpu/test_training.yml +++ b/test/configs/cpu/test_training.yml @@ -46,6 +46,7 @@ model: otf_node_attr: False # compute gradients w.r.t to positions and cell, requires otf_edge=True gradient: False + model_ensemble: 1 optim: max_epochs: 5 From 665b18edf3367c9238eb23fc214ddc72ef9944a7 Mon Sep 17 00:00:00 2001 From: Victor Fung Date: Fri, 26 Jan 2024 09:56:32 -0500 Subject: [PATCH 3/4] Update base_trainer.py --- matdeeplearn/trainers/base_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/matdeeplearn/trainers/base_trainer.py b/matdeeplearn/trainers/base_trainer.py index ba837e46..77c795b4 100644 --- a/matdeeplearn/trainers/base_trainer.py +++ b/matdeeplearn/trainers/base_trainer.py @@ -145,10 +145,10 @@ def from_config(cls, config): else: rank = torch.device("cuda" if torch.cuda.is_available() else "cpu") local_world_size = 1 - dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if hasattr(config["dataset"], "src") else None + dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if "src" in config["dataset"] else None model = cls._load_model(config["model"], config["dataset"]["preprocess_params"], dataset, local_world_size, rank) optimizer = cls._load_optimizer(config["optim"], model, local_world_size) - sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if hasattr(config["dataset"], "src") else None + sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if "src" in config["dataset"] else None data_loader = cls._load_dataloader( config["optim"], config["dataset"], @@ -156,7 +156,7 @@ def from_config(cls, config): sampler, config["task"]["run_mode"], config["model"] - ) if hasattr(config["dataset"], "src") else None + ) if "src" in config["dataset"] else None scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer) loss = cls._load_loss(config["optim"]["loss"]) From a65046174a587095c2072408de7c917c9c6ae271 Mon Sep 17 00:00:00 2001 From: Victor Fung Date: Fri, 26 Jan 2024 10:04:44 -0500 Subject: [PATCH 4/4] Update utils.py --- test/scripts/cpu/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/scripts/cpu/utils.py b/test/scripts/cpu/utils.py index 9b6abc5f..fdfe92de 100644 --- a/test/scripts/cpu/utils.py +++ b/test/scripts/cpu/utils.py @@ -20,11 +20,11 @@ def trainer_property(config, train: bool): def assert_valid_predictions(trainer, load: str): try: - out = trainer.predict(loader=trainer.data_loader[load], split="predict", write_output=False) + out = trainer.predict(loader=trainer.data_loader[0][load], split="predict", write_output=False) assert isinstance(out["predict"][0][0], (floating, float, integer, int)) assert isinstance(out["ids"][0][0], str) if load != "predict_loader": assert isinstance(out["target"][0][0], (floating, float, integer, int)) except: assert False - \ No newline at end of file +