diff --git a/matdeeplearn/trainers/base_trainer.py b/matdeeplearn/trainers/base_trainer.py index ba837e46..77c795b4 100644 --- a/matdeeplearn/trainers/base_trainer.py +++ b/matdeeplearn/trainers/base_trainer.py @@ -145,10 +145,10 @@ def from_config(cls, config): else: rank = torch.device("cuda" if torch.cuda.is_available() else "cpu") local_world_size = 1 - dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if hasattr(config["dataset"], "src") else None + dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if "src" in config["dataset"] else None model = cls._load_model(config["model"], config["dataset"]["preprocess_params"], dataset, local_world_size, rank) optimizer = cls._load_optimizer(config["optim"], model, local_world_size) - sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if hasattr(config["dataset"], "src") else None + sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if "src" in config["dataset"] else None data_loader = cls._load_dataloader( config["optim"], config["dataset"], @@ -156,7 +156,7 @@ def from_config(cls, config): sampler, config["task"]["run_mode"], config["model"] - ) if hasattr(config["dataset"], "src") else None + ) if "src" in config["dataset"] else None scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer) loss = cls._load_loss(config["optim"]["loss"]) diff --git a/test/configs/cpu/test_predict.yml b/test/configs/cpu/test_predict.yml index f6efe6f7..6387fef9 100644 --- a/test/configs/cpu/test_predict.yml +++ b/test/configs/cpu/test_predict.yml @@ -46,6 +46,7 @@ model: otf_node_attr: False # compute gradients w.r.t to positions and cell, requires otf_edge=True gradient: False + model_ensemble: 1 optim: max_epochs: 5 diff --git a/test/configs/cpu/test_training.yml b/test/configs/cpu/test_training.yml index a4d38b74..6515e6b6 100644 --- a/test/configs/cpu/test_training.yml +++ b/test/configs/cpu/test_training.yml @@ -46,6 +46,7 @@ model: otf_node_attr: False # compute gradients w.r.t to positions and cell, requires otf_edge=True gradient: False + model_ensemble: 1 optim: max_epochs: 5 diff --git a/test/scripts/cpu/utils.py b/test/scripts/cpu/utils.py index 9b6abc5f..fdfe92de 100644 --- a/test/scripts/cpu/utils.py +++ b/test/scripts/cpu/utils.py @@ -20,11 +20,11 @@ def trainer_property(config, train: bool): def assert_valid_predictions(trainer, load: str): try: - out = trainer.predict(loader=trainer.data_loader[load], split="predict", write_output=False) + out = trainer.predict(loader=trainer.data_loader[0][load], split="predict", write_output=False) assert isinstance(out["predict"][0][0], (floating, float, integer, int)) assert isinstance(out["ids"][0][0], str) if load != "predict_loader": assert isinstance(out["target"][0][0], (floating, float, integer, int)) except: assert False - \ No newline at end of file +