From d05b7136340ef6567d334a1aba45fd3d44009725 Mon Sep 17 00:00:00 2001 From: Rithwik Seth Date: Sun, 12 Nov 2023 13:57:11 -0800 Subject: [PATCH 1/3] Added crystal graph --- configs/crystalGraphConfig.yml | 134 ++++++++++ matdeeplearn/models/crystal_graph.py | 250 +++++++++++++++++++ matdeeplearn/models/crystal_graph_multi.py | 271 +++++++++++++++++++++ 3 files changed, 655 insertions(+) create mode 100644 configs/crystalGraphConfig.yml create mode 100644 matdeeplearn/models/crystal_graph.py create mode 100644 matdeeplearn/models/crystal_graph_multi.py diff --git a/configs/crystalGraphConfig.yml b/configs/crystalGraphConfig.yml new file mode 100644 index 00000000..1b36e954 --- /dev/null +++ b/configs/crystalGraphConfig.yml @@ -0,0 +1,134 @@ +trainer: property + +task: + #run_mode: train + identifier: my_train_job + parallel: False + # If seed is not set, then it will be random every time + seed: 12345678 + # Defaults to run directory if not specified + save_dir: + # continue from a previous job + continue_job: False + # specify if the training state is loaded: epochs, learning rate, etc + load_training_state: False + # Path to the checkpoint.pt file + checkpoint_path: + # Whether to write predictions to csv file. E.g. 
["train", "val", "test"] + write_output: [train, val, test] + # Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time + output_frequency: 0 + # Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint + model_save_frequency: 0 + # Specify if labels are provided for the predict task + # labels: True + # Use amp mixed precision + use_amp: True + +model: + name: CrystalGraphMulti + # model attributes + dim1: 128 + dim2: 128 + n_conv: 4 + n_h: 3 + k: 3 + pool: "global_mean_pool" + pool_order: "early" + batch_norm: True + batch_track_stats: True + act: "relu" + dropout_rate: 0.0 + classification: False + # Compute edge indices on the fly in the model forward + otf_edge_index: False + # Compute edge attributes on the fly in the model forward + otf_edge_attr: False + # Compute node attributes on the fly in the model forward + otf_node_attr: False + # compute gradients w.r.t to positions and cell, requires otf_edge_attr=True + gradient: False + +optim: + max_epochs: 40 + max_checkpoint_epochs: 0 + lr: 0.001 + # Either custom or from torch.nn.functional library. 
If from torch, loss_type is TorchLossWrapper + loss: + loss_type: TorchLossWrapper + loss_args: {loss_fn: l1_loss} + # gradient clipping value + clip_grad_norm: 10 + batch_size: 100 + optimizer: + optimizer_type: AdamW + optimizer_args: {} + scheduler: + scheduler_type: ReduceLROnPlateau + scheduler_args: {mode: min, factor: 0.8, patience: 10, min_lr: 0.00001, threshold: 0.0002} + #Training print out frequency (print per n number of epochs) + verbosity: 5 + # tdqm progress bar per batch in the epoch + batch_tqdm: False + +dataset: + name: test_data + # Whether the data has already been processed and a data.pt file is present from a previous run + processed: False + # Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path} + src: data/test_data/data_graph_scalar.json + # Path to target file within data_path - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path} or left blank when the dataset is a single json file + # Example: target_path: "data/raw_graph_scalar/targets.csv" + target_path: + # Path to save processed data.pt file + pt_path: data/ + # Either "node" or "graph" level + prediction_level: graph + + transforms: + - name: GetY + args: + # index specifies the index of a target vector to predict, which is useful when there are multiple property labels for a single dataset + # For example, an index: 0 (default) will use the first entry in the target vector + # if all values are to be predicted simultaneously, then specify index: -1 + index: -1 + otf_transform: True # Optional parameter, default is True + # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) + data_format: json + # specify if additional attributes to be loaded into the dataset from the .json file; e.g. 
additional_attributes: [forces, stress] + additional_attributes: + # Print out processing info + verbose: True + # Index of target column in targets.csv + # graph specific settings + preprocess_params: + # one of mdl (minimum image convention), ocp (all neighbors included) + edge_calc_method: mdl + # determine if edges are computed, if false, then they need to be computed on the fly + preprocess_edges: True + # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_edge_features: True + # determine if node attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_node_features: True + # distance cutoff to determine if two atoms are connected by an edge + cutoff_radius : 8.0 + # maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors) + n_neighbors : 250 + # number of pbc offsets to consider when determining neighbors (usually not changed) + num_offsets: 2 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 50 + # whether or not to add self-loops + self_loop: True + # Method of obtaining atom dictionary: available: (onehot) + node_representation: onehot + # Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cpu + # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data) + train_ratio: 0.8 + val_ratio: 0.05 + test_ratio: 0.15 diff --git a/matdeeplearn/models/crystal_graph.py b/matdeeplearn/models/crystal_graph.py new file mode 100644 index 00000000..93e78f26 --- /dev/null +++ b/matdeeplearn/models/crystal_graph.py @@ -0,0 +1,250 @@ +from __future__ import print_function, division + +import torch +import numpy as np +import torch.nn as nn +from matdeeplearn.models.base_model import BaseModel, 
conditional_grad +from matdeeplearn.common.registry import registry +from torch_scatter import scatter, segment_coo +from torch_geometric.nn import ( + global_mean_pool, + MessagePassing, +) +import torch.nn.functional as F +import torch_geometric +import warnings + +warnings.filterwarnings("ignore") + +#in comparison to my2, here we use softmax rather than sigmoid +#we also have the choice to use max pool rather than mean pool +class ConvLayer(MessagePassing): + """ + Convolutional operation on graphs + """ + def __init__(self, atom_fea_len, edge_fea_len): + """ + Initialize ConvLayer. + + Parameters + ---------- + + atom_fea_len: int + Number of atom hidden features. + nbr_fea_len: int + Number of bond features. + """ + super(ConvLayer, self).__init__(aggr="add", node_dim=0) + self.atom_fea_len = atom_fea_len + self.edge_fea_len = edge_fea_len + self.fc_full=nn.Linear(2*self.atom_fea_len+self.edge_fea_len, + 2*self.atom_fea_len+self.edge_fea_len) + self.fc_f = nn.Linear(2*self.atom_fea_len+self.edge_fea_len, self.atom_fea_len) + self.fc_s = nn.Linear(2*self.atom_fea_len+self.edge_fea_len, self.atom_fea_len) + self.softmax= nn.Softmax(dim=1) + self.softplus1 = nn.ReLU() + self.bn1 = nn.BatchNorm1d(2*self.atom_fea_len+self.edge_fea_len) + self.bn2 = nn.BatchNorm1d(self.atom_fea_len) + self.dropout = nn.Dropout() + + def forward(self, x, edge_index, distances): + self.edge_attrs = distances + aggregatedMessages = self.propagate(edge_index, x=x, distances=distances) + aggregatedMessages = self.bn2(aggregatedMessages) + out = aggregatedMessages + x + return out, self.edge_attrs + + + + def message(self, x_i, x_j, distances): + #concatenate atom features, bond features, and bond distances + z = torch.cat([x_i, x_j, distances], dim=-1) + #fully connected layer + total_gated_fea = self.fc_full(z) + total_gated_fea = self.bn1(total_gated_fea) + #split into atom features, bond features, and bond distances and apply functions + nbr_filter, nbr_core, new_edge_attrs = 
total_gated_fea.split([self.atom_fea_len,self.atom_fea_len,self.edge_fea_len], dim=1) + #aggregate and return + self.edge_attrs += new_edge_attrs + nbr_filter = self.fc_f(z) + nbr_core = self.fc_s(z) + return self.softmax(nbr_filter) * self.softplus1(nbr_core) + + + + + +@registry.register_model("CrystalGraph") +class CrystalGraphConvNet(BaseModel): + """ + Create a crystal graph convolutional neural network for predicting total + material properties. + """ + def __init__(self, node_dim, edge_dim, output_dim, + dim1=128, n_conv=4, dim2=128, n_h=1, pool="global_mean_pool", + pool_order="early", act="relu", classification=False, **kwargs): + """ + Initialize CrystalGraphConvNet. + + Parameters + ---------- + + node_dim: int + Number of atom features in the input. + edge_dim: int + Number of bond features. + dim1: int + Number of hidden atom features in the convolutional layers + n_conv: int + Number of convolutional layers + dim2: int + Number of hidden features after pooling + n_h: int + Number of hidden layers after pooling + """ + super(CrystalGraphConvNet, self).__init__() + self.classification = classification + self.embedding = nn.Linear(node_dim, dim1) + self.convs = nn.ModuleList([ConvLayer(atom_fea_len=dim1, + edge_fea_len=edge_dim) + for _ in range(n_conv)]) + self.conv_to_fc = nn.Linear(dim1, dim2) + self.conv_to_fc_softplus = nn.ReLU() + self.output_softplus= nn.ReLU() + if n_h > 1: + self.fcs = nn.ModuleList([nn.Linear(dim2, dim2) + for _ in range(n_h-1)]) + if self.classification: + self.fc_out = nn.Linear(dim2, 2) + else: + self.fc_out = nn.Linear(dim2, output_dim) + if self.classification: + self.logsoftmax = nn.LogSoftmax(dim=1) + self.dropout = nn.Dropout() + + self.pool = pool + self.act = act + self.pool_order = pool_order + + def forward(self, data): + + output = {} + out = self._forward(data) + output["output"] = out + + if self.gradient == True and out.requires_grad == True: + if self.gradient_method == "conventional": + volume = 
torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.cell], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + #For calculation of stress, see https://github.com/mir-group/nequip/blob/main/nequip/nn/_grad_output.py + #Originally from: https://github.com/atomistic-machine-learning/schnetpack/issues/165 + elif self.gradient_method == "nequip": + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + + output["pos_grad"] = forces + output["cell_grad"] = stress + else: + output["pos_grad"] = None + output["cell_grad"] = None + + return output + @conditional_grad(torch.enable_grad()) + def _forward(self, data): + """ + Forward pass + + data: graph features + + Parameters + ---------- + + data: + data.x: node features + shape = (N, node_dim) + data.edge_index: list of edges + shape = (2, E) + data.edge_attr: edge attributes (distances) + shape = (E, edge_dim) + data.batch: crystal id for each node + shape = (N, ) + + Returns + ------- + + prediction: graph predictions + shape = (batch_size, ) + """ + if self.otf_edge_index == True: + #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + else: + logging.warning("Edge attributes should be re-computed for otf 
edge indices.") + + if self.otf_edge_index == False: + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() + + #initialize variables + atom_fea = data.x + edge_index = data.edge_index + distances = data.edge_attr + #embed atom features + atom_fea = self.embedding(atom_fea) + #convolutional layers + for conv_func in self.convs: + atom_fea, distances = conv_func(atom_fea, edge_index, distances) + + # Post-GNN dense layers + if self.prediction_level == "graph": + if self.pool_order == "early": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + elif self.pool_order == "late": + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + + elif self.prediction_level == "node": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + + return out + + @property + def target_attr(self): + return "y" + diff --git a/matdeeplearn/models/crystal_graph_multi.py b/matdeeplearn/models/crystal_graph_multi.py new file mode 100644 index 00000000..e60bd5ec --- /dev/null +++ b/matdeeplearn/models/crystal_graph_multi.py @@ -0,0 +1,271 @@ +from __future__ import print_function, division + +import torch +import numpy as np +import torch.nn 
as nn +from matdeeplearn.models.base_model import BaseModel, conditional_grad +from matdeeplearn.common.registry import registry +from torch_scatter import scatter, segment_coo +from torch_geometric.nn import ( + global_mean_pool, + MessagePassing, +) +import torch.nn.functional as F +import torch_geometric + +import warnings + +warnings.filterwarnings("ignore") + +#in comparison to my2, here we use softmax rather than sigmoid +#we also have the choice to use max pool rather than mean pool +class ConvLayer(MessagePassing): + """ + Convolutional operation on graphs + """ + def __init__(self, atom_fea_len, nbr_fea_len, k=3): + """ + Initialize ConvLayer. + + Parameters + ---------- + + atom_fea_len: int + Number of atom hidden features. + nbr_fea_len: int + Number of bond features. + """ + super(ConvLayer, self).__init__(aggr="add", node_dim=0) + self.atom_fea_len = atom_fea_len + self.nbr_fea_len = nbr_fea_len + self.k = k + self.fc_fulls=nn.ModuleList([nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, + 2*self.atom_fea_len+self.nbr_fea_len) for i in range(k)]) + self.fc_fs=nn.ModuleList([nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, + self.atom_fea_len) for i in range(k)]) + self.fc_ss=nn.ModuleList([nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, + self.atom_fea_len) for i in range(k)]) + self.softmax= nn.Softmax(dim=1) + self.softmax2= nn.Softmax(dim=2) + self.softmax3= nn.Softmax(dim=2) + self.softplus1 = nn.ReLU() + self.bn1s = nn.ModuleList([nn.BatchNorm1d(2*self.atom_fea_len+self.nbr_fea_len) for i in range(k)]) + self.bn2s = nn.ModuleList([nn.BatchNorm1d(self.atom_fea_len) for i in range(k)]) + self.atom_fc = nn.Linear(self.k , 2*self.k) + self.nbr_fc = nn.Linear(self.k , 2*self.k) + self.dropout = nn.Dropout() + + def forward(self, x, edge_index, distances): + self.new_attr = [] + self.outs = [] + self.currInd = 0 + for i in range(self.k): + self.outs.append(x) + self.new_attr.append(distances) + self.edge_attr = distances + for i in range(self.k): + 
aggregatedMessages = self.propagate(edge_index, x=x, distances=distances) + aggregatedMessages = self.bn2s[i](aggregatedMessages) + self.outs[i] = aggregatedMessages + x + out=torch.stack(self.outs, dim=2) + new_nbr=torch.stack(self.new_attr, dim=2) + + out_gated=self.atom_fc(out) + new_nbr_gated=self.nbr_fc(new_nbr) + + out_core, out_filter = out_gated.split([self.k, self.k], dim=2) + new_nbr_core, new_nbr_filter = new_nbr_gated.split([self.k, self.k], dim=2) + out_filter=self.softmax2(out_filter) + new_nbr_filter=self.softmax3(new_nbr_filter) + out = torch.sum(out_core * out_filter, dim=2) + new_nbr = torch.sum(new_nbr_core* new_nbr_filter, dim=2) + return out, new_nbr + + + + def message(self, x_i, x_j, distances): + z = torch.cat([x_i, x_j, distances], dim=-1) + total_gated_fea = self.fc_fulls[self.currInd](z) + total_gated_fea = self.bn1s[self.currInd](total_gated_fea) + nbr_filter, nbr_core, new_edge_attrs = total_gated_fea.split([self.atom_fea_len,self.atom_fea_len,self.nbr_fea_len], dim=1) + self.new_attr[self.currInd] += new_edge_attrs + self.currInd += 1 + nbr_filter = self.fc_fs[self.currInd-1](z) + nbr_core = self.fc_ss[self.currInd-1](z) + return self.softmax(nbr_filter) * self.softplus1(nbr_core) + + + + + +@registry.register_model("CrystalGraphMulti") +class CrystalGraphConvNet(BaseModel): + """ + Create a crystal graph convolutional neural network for predicting total + material properties. + """ + def __init__(self, node_dim, edge_dim, output_dim, + dim1=128, n_conv=4, dim2=128, n_h=1,k=3, pool="global_mean_pool", + pool_order="early", act="relu", classification=False, **kwargs): + """ + Initialize CrystalGraphConvNet. + + Parameters + ---------- + + node_dim: int + Number of atom features in the input. + edge_dim: int + Number of bond features. 
+ dim1: int + Number of hidden atom features in the convolutional layers + n_conv: int + Number of convolutional layers + dim2: int + Number of hidden features after pooling + n_h: int + Number of hidden layers after pooling + """ + super(CrystalGraphConvNet, self).__init__() + self.classification = classification + self.embedding = nn.Linear(node_dim, dim1) + self.convs = nn.ModuleList([ConvLayer(atom_fea_len=dim1, + nbr_fea_len=edge_dim,k=k) + for _ in range(n_conv)]) + self.conv_to_fc = nn.Linear(dim1, dim2) + self.conv_to_fc_softplus = nn.ReLU() + self.output_softplus= nn.ReLU() + if n_h > 1: + self.fcs = nn.ModuleList([nn.Linear(dim2, dim2) + for _ in range(n_h-1)]) + self.softpluses = nn.ModuleList([nn.ReLU() + for _ in range(n_h-1)]) + if self.classification: + self.fc_out = nn.Linear(dim2, 2) + else: + self.fc_out = nn.Linear(dim2, output_dim) + if self.classification: + self.logsoftmax = nn.LogSoftmax(dim=1) + self.dropout = nn.Dropout() + + self.pool = pool + self.act = act + self.pool_order = pool_order + + def forward(self, data): + + output = {} + out = self._forward(data) + output["output"] = out + + if self.gradient == True and out.requires_grad == True: + if self.gradient_method == "conventional": + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.cell], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + #For calculation of stress, see https://github.com/mir-group/nequip/blob/main/nequip/nn/_grad_output.py + #Originally from: https://github.com/atomistic-machine-learning/schnetpack/issues/165 + elif self.gradient_method == "nequip": + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, 
data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + + output["pos_grad"] = forces + output["cell_grad"] = stress + else: + output["pos_grad"] = None + output["cell_grad"] = None + + return output + @conditional_grad(torch.enable_grad()) + def _forward(self, data): + """ + Forward pass + + data: graph features + + Parameters + ---------- + + data: + data.x: node features + shape = (N, node_dim) + data.edge_index: list of edges + shape = (2, E) + data.edge_attr: edge attributes (distances) + shape = (E, edge_dim) + data.batch: crystal id for each node + shape = (N, ) + + Returns + ------- + + prediction: graph predictions + shape = (batch_size, ) + """ + if self.otf_edge_index == True: + #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + else: + logging.warning("Edge attributes should be re-computed for otf edge indices.") + + if self.otf_edge_index == False: + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() + + atom_fea = data.x + edge_index = data.edge_index + distances = data.edge_attr + atom_fea = self.embedding(atom_fea) + for conv_func in self.convs: + atom_fea, distances = conv_func(atom_fea, edge_index, distances) + + if self.prediction_level == "graph": + if self.pool_order == "early": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + 
crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + elif self.pool_order == "late": + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + + elif self.prediction_level == "node": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + + return out + @property + def target_attr(self): + return "y" + From 7f2b753b84c0b4443ba0354707789ff30271c11a Mon Sep 17 00:00:00 2001 From: Rithwik Seth Date: Mon, 13 Nov 2023 10:32:50 -0800 Subject: [PATCH 2/3] Fixing performance and config --- configs/crystalGraphConfig.yml | 24 +++++++++++----------- matdeeplearn/models/crystal_graph.py | 5 +---- matdeeplearn/models/crystal_graph_multi.py | 12 ++++------- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/configs/crystalGraphConfig.yml b/configs/crystalGraphConfig.yml index 1b36e954..e053a51f 100644 --- a/configs/crystalGraphConfig.yml +++ b/configs/crystalGraphConfig.yml @@ -26,12 +26,12 @@ task: use_amp: True model: - name: CrystalGraphMulti + name: CrystalGraph # model attributes dim1: 128 dim2: 128 n_conv: 4 - n_h: 3 + n_h: 1 k: 3 pool: "global_mean_pool" pool_order: "early" @@ -50,7 +50,7 @@ model: gradient: False optim: - max_epochs: 40 + max_epochs: 300 max_checkpoint_epochs: 0 lr: 0.001 # Either custom or from torch.nn.functional library. 
If from torch, loss_type is TorchLossWrapper @@ -59,7 +59,7 @@ optim: loss_args: {loss_fn: l1_loss} # gradient clipping value clip_grad_norm: 10 - batch_size: 100 + batch_size: 128 optimizer: optimizer_type: AdamW optimizer_args: {} @@ -84,17 +84,17 @@ dataset: pt_path: data/ # Either "node" or "graph" level prediction_level: graph - + transforms: - name: GetY args: # index specifies the index of a target vector to predict, which is useful when there are multiple property labels for a single dataset # For example, an index: 0 (default) will use the first entry in the target vector # if all values are to be predicted simultaneously, then specify index: -1 - index: -1 + index: 0 otf_transform: True # Optional parameter, default is True # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) - data_format: json + data_format: json # specify if additional attributes to be loaded into the dataset from the .json file; e.g. additional_attributes: [forces, stress] additional_attributes: # Print out processing info @@ -103,7 +103,7 @@ dataset: # graph specific settings preprocess_params: # one of mdl (minimum image convention), ocp (all neighbors included) - edge_calc_method: mdl + edge_calc_method: ocp # determine if edges are computed, if false, then they need to be computed on the fly preprocess_edges: True # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly @@ -113,13 +113,13 @@ dataset: # distance cutoff to determine if two atoms are connected by an edge cutoff_radius : 8.0 # maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors) - n_neighbors : 250 + n_neighbors : 12 # number of pbc offsets to consider when determining neighbors (usually not changed) num_offsets: 2 # dimension of node attributes node_dim : 100 # dimension of edge attributes - edge_dim : 50 + edge_dim : 64 # whether or not to add self-loops 
self_loop: True # Method of obtaining atom dictionary: available: (onehot) @@ -130,5 +130,5 @@ dataset: dataset_device: cpu # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data) train_ratio: 0.8 - val_ratio: 0.05 - test_ratio: 0.15 + val_ratio: 0.1 + test_ratio: 0.1 diff --git a/matdeeplearn/models/crystal_graph.py b/matdeeplearn/models/crystal_graph.py index 93e78f26..ecea846c 100644 --- a/matdeeplearn/models/crystal_graph.py +++ b/matdeeplearn/models/crystal_graph.py @@ -66,10 +66,7 @@ def message(self, x_i, x_j, distances): nbr_filter, nbr_core, new_edge_attrs = total_gated_fea.split([self.atom_fea_len,self.atom_fea_len,self.edge_fea_len], dim=1) #aggregate and return self.edge_attrs += new_edge_attrs - nbr_filter = self.fc_f(z) - nbr_core = self.fc_s(z) - return self.softmax(nbr_filter) * self.softplus1(nbr_core) - + return self.softmax(self.fc_f(z)) * self.softplus1(self.fc_s(z)) diff --git a/matdeeplearn/models/crystal_graph_multi.py b/matdeeplearn/models/crystal_graph_multi.py index e60bd5ec..7bf4c772 100644 --- a/matdeeplearn/models/crystal_graph_multi.py +++ b/matdeeplearn/models/crystal_graph_multi.py @@ -62,7 +62,6 @@ def forward(self, x, edge_index, distances): for i in range(self.k): self.outs.append(x) self.new_attr.append(distances) - self.edge_attr = distances for i in range(self.k): aggregatedMessages = self.propagate(edge_index, x=x, distances=distances) aggregatedMessages = self.bn2s[i](aggregatedMessages) @@ -90,10 +89,7 @@ def message(self, x_i, x_j, distances): nbr_filter, nbr_core, new_edge_attrs = total_gated_fea.split([self.atom_fea_len,self.atom_fea_len,self.nbr_fea_len], dim=1) self.new_attr[self.currInd] += new_edge_attrs self.currInd += 1 - nbr_filter = self.fc_fs[self.currInd-1](z) - nbr_core = self.fc_ss[self.currInd-1](z) - return self.softmax(nbr_filter) * self.softplus1(nbr_core) - + return self.softmax(self.fc_fs[self.currInd-1](z)) * self.softplus1(self.fc_ss[self.currInd-1](z)) @@ 
-216,7 +212,6 @@ def _forward(self, data): shape = (batch_size, ) """ if self.otf_edge_index == True: - #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) if self.otf_edge_attr == True: data.edge_attr = self.distance_expansion(data.edge_weight) @@ -230,6 +225,7 @@ def _forward(self, data): if self.otf_node_attr == True: data.x = node_rep_one_hot(data.z).float() + atom_fea = data.x edge_index = data.edge_index distances = data.edge_attr @@ -263,8 +259,8 @@ def _forward(self, data): for fc in self.fcs: crys_fea = getattr(F, self.act)(fc(crys_fea)) out = self.fc_out(crys_fea) - - return out + return out + @property def target_attr(self): return "y" From a9152738c0636552ddcf6a4bd5c4afa6e7d07e47 Mon Sep 17 00:00:00 2001 From: Rithwik Seth Date: Mon, 13 Nov 2023 11:41:46 -0800 Subject: [PATCH 3/3] Increasing num neighbors --- configs/crystalGraphConfig.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/crystalGraphConfig.yml b/configs/crystalGraphConfig.yml index e053a51f..171417e3 100644 --- a/configs/crystalGraphConfig.yml +++ b/configs/crystalGraphConfig.yml @@ -70,7 +70,7 @@ optim: verbosity: 5 # tdqm progress bar per batch in the epoch batch_tqdm: False - + dataset: name: test_data # Whether the data has already been processed and a data.pt file is present from a previous run @@ -113,7 +113,7 @@ dataset: # distance cutoff to determine if two atoms are connected by an edge cutoff_radius : 8.0 # maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors) - n_neighbors : 12 + n_neighbors : 250 # number of pbc offsets to consider when determining neighbors (usually not changed) num_offsets: 2 # dimension of node attributes