From 126d4c78828a260ce7ce72d5f1d69ad263782fd0 Mon Sep 17 00:00:00 2001 From: Victor Fung Date: Sun, 21 Jan 2024 21:40:02 -0500 Subject: [PATCH] Update ase_utils.py --- matdeeplearn/common/ase_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/matdeeplearn/common/ase_utils.py b/matdeeplearn/common/ase_utils.py index ecf2acb0..77e089ab 100644 --- a/matdeeplearn/common/ase_utils.py +++ b/matdeeplearn/common/ase_utils.py @@ -99,9 +99,9 @@ def calculate(self, atoms: Atoms, properties=implemented_properties, system_chan forces = torch.stack([entry["pos_grad"] for entry in out_list]).mean(dim=0) stresses = torch.stack([entry["cell_grad"] for entry in out_list]).mean(dim=0) - self.results['energy'] = energy.detach().cpu().numpy() - self.results['forces'] = forces.detach().cpu().numpy() - self.results['stress'] = stresses.squeeze().detach().cpu().numpy() + self.results['energy'] = energy.detach().cpu().numpy().squeeze() + self.results['forces'] = forces.detach().cpu().numpy().squeeze() + self.results['stress'] = stresses.squeeze().detach().cpu().numpy().squeeze() @staticmethod def data_to_atoms_list(data: Data) -> List[Atoms]: @@ -148,7 +148,7 @@ def _load_model(config: dict, rank: str) -> List[BaseModel]: model_config = config['model'] model_list = [] - model_name = 'matdeeplearn.models.' + model_config["name"] + model_name = model_config["name"] logging.info(f'MDLCalculator: setting up {model_name} for calculation') # Obtain node, edge, and output dimensions for model initialization for _ in range(model_config["model_ensemble"]): @@ -181,4 +181,4 @@ def _load_model(config: dict, rank: str) -> List[BaseModel]: except ValueError: logging.warning(f"MDLCalculator: No checkpoint.pt file is found for model No.{i+1}, and an untrained model is used for prediction.") - return model_list \ No newline at end of file + return model_list