diff --git a/configs/crystalGraphConfig.yml b/configs/crystalGraphConfig.yml new file mode 100644 index 00000000..171417e3 --- /dev/null +++ b/configs/crystalGraphConfig.yml @@ -0,0 +1,134 @@ +trainer: property + +task: + #run_mode: train + identifier: my_train_job + parallel: False + # If seed is not set, then it will be random every time + seed: 12345678 + # Defaults to run directory if not specified + save_dir: + # continue from a previous job + continue_job: False + # specify if the training state is loaded: epochs, learning rate, etc + load_training_state: False + # Path to the checkpoint.pt file + checkpoint_path: + # Whether to write predictions to csv file. E.g. ["train", "val", "test"] + write_output: [train, val, test] + # Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time + output_frequency: 0 + # Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint + model_save_frequency: 0 + # Specify if labels are provided for the predict task + # labels: True + # Use amp mixed precision + use_amp: True + +model: + name: CrystalGraph + # model attributes + dim1: 128 + dim2: 128 + n_conv: 4 + n_h: 1 + k: 3 + pool: "global_mean_pool" + pool_order: "early" + batch_norm: True + batch_track_stats: True + act: "relu" + dropout_rate: 0.0 + classification: False + # Compute edge indices on the fly in the model forward + otf_edge_index: False + # Compute edge attributes on the fly in the model forward + otf_edge_attr: False + # Compute node attributes on the fly in the model forward + otf_node_attr: False + # compute gradients w.r.t. positions and cell, requires otf_edge_attr=True + gradient: False + +optim: + max_epochs: 300 + max_checkpoint_epochs: 0 + lr: 0.001 + # Either custom or from torch.nn.functional library. 
If from torch, loss_type is TorchLossWrapper + loss: + loss_type: TorchLossWrapper + loss_args: {loss_fn: l1_loss} + # gradient clipping value + clip_grad_norm: 10 + batch_size: 128 + optimizer: + optimizer_type: AdamW + optimizer_args: {} + scheduler: + scheduler_type: ReduceLROnPlateau + scheduler_args: {mode: min, factor: 0.8, patience: 10, min_lr: 0.00001, threshold: 0.0002} + # Training print out frequency (print per n number of epochs) + verbosity: 5 + # tqdm progress bar per batch in the epoch + batch_tqdm: False + +dataset: + name: test_data + # Whether the data has already been processed and a data.pt file is present from a previous run + processed: False + # Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path} + src: data/test_data/data_graph_scalar.json + # Path to target file within data_path - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path} or left blank when the dataset is a single json file + # Example: target_path: "data/raw_graph_scalar/targets.csv" + target_path: + # Path to save processed data.pt file + pt_path: data/ + # Either "node" or "graph" level + prediction_level: graph + + transforms: + - name: GetY + args: + # index specifies the index of a target vector to predict, which is useful when there are multiple property labels for a single dataset + # For example, an index: 0 (default) will use the first entry in the target vector + # if all values are to be predicted simultaneously, then specify index: -1 + index: 0 + otf_transform: True # Optional parameter, default is True + # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) + data_format: json + # specify if additional attributes are to be loaded into the dataset from the .json file; e.g. 
additional_attributes: [forces, stress] + additional_attributes: + # Print out processing info + verbose: True + # Index of target column in targets.csv + # graph specific settings + preprocess_params: + # one of mic (minimum image convention), ocp (all neighbors included) + edge_calc_method: ocp + # determine if edges are computed, if false, then they need to be computed on the fly + preprocess_edges: True + # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_edge_features: True + # determine if node attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_node_features: True + # distance cutoff to determine if two atoms are connected by an edge + cutoff_radius : 8.0 + # maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors) + n_neighbors : 250 + # number of pbc offsets to consider when determining neighbors (usually not changed) + num_offsets: 2 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 64 + # whether or not to add self-loops + self_loop: True + # Method of obtaining atom dictionary: available: (onehot) + node_representation: onehot + # Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cpu + # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data) + train_ratio: 0.8 + val_ratio: 0.1 + test_ratio: 0.1 diff --git a/matdeeplearn/models/crystal_graph.py b/matdeeplearn/models/crystal_graph.py new file mode 100644 index 00000000..ecea846c --- /dev/null +++ b/matdeeplearn/models/crystal_graph.py @@ -0,0 +1,247 @@ +from __future__ import print_function, division + +import torch +import numpy as np +import torch.nn as nn +from matdeeplearn.models.base_model import BaseModel, 
conditional_grad +from matdeeplearn.common.registry import registry +from torch_scatter import scatter, segment_coo +from torch_geometric.nn import ( + global_mean_pool, + MessagePassing, +) +import torch.nn.functional as F +import torch_geometric +import warnings + +warnings.filterwarnings("ignore") + +#in comparison to my2, here we use softmax rather than sigmoid +#we also have the choice to use max pool rather than mean pool +class ConvLayer(MessagePassing): + """ + Convolutional operation on graphs + """ + def __init__(self, atom_fea_len, edge_fea_len): + """ + Initialize ConvLayer. + + Parameters + ---------- + + atom_fea_len: int + Number of atom hidden features. + nbr_fea_len: int + Number of bond features. + """ + super(ConvLayer, self).__init__(aggr="add", node_dim=0) + self.atom_fea_len = atom_fea_len + self.edge_fea_len = edge_fea_len + self.fc_full=nn.Linear(2*self.atom_fea_len+self.edge_fea_len, + 2*self.atom_fea_len+self.edge_fea_len) + self.fc_f = nn.Linear(2*self.atom_fea_len+self.edge_fea_len, self.atom_fea_len) + self.fc_s = nn.Linear(2*self.atom_fea_len+self.edge_fea_len, self.atom_fea_len) + self.softmax= nn.Softmax(dim=1) + self.softplus1 = nn.ReLU() + self.bn1 = nn.BatchNorm1d(2*self.atom_fea_len+self.edge_fea_len) + self.bn2 = nn.BatchNorm1d(self.atom_fea_len) + self.dropout = nn.Dropout() + + def forward(self, x, edge_index, distances): + self.edge_attrs = distances + aggregatedMessages = self.propagate(edge_index, x=x, distances=distances) + aggregatedMessages = self.bn2(aggregatedMessages) + out = aggregatedMessages + x + return out, self.edge_attrs + + + + def message(self, x_i, x_j, distances): + #concatenate atom features, bond features, and bond distances + z = torch.cat([x_i, x_j, distances], dim=-1) + #fully connected layer + total_gated_fea = self.fc_full(z) + total_gated_fea = self.bn1(total_gated_fea) + #split into atom features, bond features, and bond distances and apply functions + nbr_filter, nbr_core, new_edge_attrs = 
total_gated_fea.split([self.atom_fea_len,self.atom_fea_len,self.edge_fea_len], dim=1) + #aggregate and return + self.edge_attrs += new_edge_attrs + return self.softmax(self.fc_f(z)) * self.softplus1(self.fc_s(z)) + + + + +@registry.register_model("CrystalGraph") +class CrystalGraphConvNet(BaseModel): + """ + Create a crystal graph convolutional neural network for predicting total + material properties. + """ + def __init__(self, node_dim, edge_dim, output_dim, + dim1=128, n_conv=4, dim2=128, n_h=1, pool="global_mean_pool", + pool_order="early", act="relu", classification=False, **kwargs): + """ + Initialize CrystalGraphConvNet. + + Parameters + ---------- + + node_dim: int + Number of atom features in the input. + edge_dim: int + Number of bond features. + dim1: int + Number of hidden atom features in the convolutional layers + n_conv: int + Number of convolutional layers + dim2: int + Number of hidden features after pooling + n_h: int + Number of hidden layers after pooling + """ + super(CrystalGraphConvNet, self).__init__() + self.classification = classification + self.embedding = nn.Linear(node_dim, dim1) + self.convs = nn.ModuleList([ConvLayer(atom_fea_len=dim1, + edge_fea_len=edge_dim) + for _ in range(n_conv)]) + self.conv_to_fc = nn.Linear(dim1, dim2) + self.conv_to_fc_softplus = nn.ReLU() + self.output_softplus= nn.ReLU() + if n_h > 1: + self.fcs = nn.ModuleList([nn.Linear(dim2, dim2) + for _ in range(n_h-1)]) + if self.classification: + self.fc_out = nn.Linear(dim2, 2) + else: + self.fc_out = nn.Linear(dim2, output_dim) + if self.classification: + self.logsoftmax = nn.LogSoftmax(dim=1) + self.dropout = nn.Dropout() + + self.pool = pool + self.act = act + self.pool_order = pool_order + + def forward(self, data): + + output = {} + out = self._forward(data) + output["output"] = out + + if self.gradient == True and out.requires_grad == True: + if self.gradient_method == "conventional": + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], 
torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.cell], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + #For calculation of stress, see https://github.com/mir-group/nequip/blob/main/nequip/nn/_grad_output.py + #Originally from: https://github.com/atomistic-machine-learning/schnetpack/issues/165 + elif self.gradient_method == "nequip": + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + + output["pos_grad"] = forces + output["cell_grad"] = stress + else: + output["pos_grad"] = None + output["cell_grad"] = None + + return output + @conditional_grad(torch.enable_grad()) + def _forward(self, data): + """ + Forward pass + + data: graph features + + Parameters + ---------- + + data: + data.x: node features + shape = (N, node_dim) + data.edge_index: list of edges + shape = (2, E) + data.edge_attr: edge attributes (distances) + shape = (E, edge_dim) + data.batch: crystal id for each node + shape = (N, ) + + Returns + ------- + + prediction: graph predictions + shape = (batch_size, ) + """ + if self.otf_edge_index == True: + #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + else: + logging.warning("Edge attributes should be re-computed for otf edge indices.") + + if self.otf_edge_index == 
False: + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() + + #initialize variables + atom_fea = data.x + edge_index = data.edge_index + distances = data.edge_attr + #embed atom features + atom_fea = self.embedding(atom_fea) + #convolutional layers + for conv_func in self.convs: + atom_fea, distances = conv_func(atom_fea, edge_index, distances) + + # Post-GNN dense layers + if self.prediction_level == "graph": + if self.pool_order == "early": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + elif self.pool_order == "late": + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + + elif self.prediction_level == "node": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + + return out + + @property + def target_attr(self): + return "y" + diff --git a/matdeeplearn/models/crystal_graph_multi.py b/matdeeplearn/models/crystal_graph_multi.py new file mode 100644 index 00000000..7bf4c772 --- /dev/null +++ b/matdeeplearn/models/crystal_graph_multi.py @@ -0,0 +1,267 @@ +from __future__ import print_function, division + +import torch +import numpy as np +import torch.nn as nn +from matdeeplearn.models.base_model 
import BaseModel, conditional_grad +from matdeeplearn.common.registry import registry +from torch_scatter import scatter, segment_coo +from torch_geometric.nn import ( + global_mean_pool, + MessagePassing, +) +import torch.nn.functional as F +import torch_geometric + +import warnings + +warnings.filterwarnings("ignore") + +#in comparison to my2, here we use softmax rather than sigmoid +#we also have the choice to use max pool rather than mean pool +class ConvLayer(MessagePassing): + """ + Convolutional operation on graphs + """ + def __init__(self, atom_fea_len, nbr_fea_len, k=3): + """ + Initialize ConvLayer. + + Parameters + ---------- + + atom_fea_len: int + Number of atom hidden features. + nbr_fea_len: int + Number of bond features. + """ + super(ConvLayer, self).__init__(aggr="add", node_dim=0) + self.atom_fea_len = atom_fea_len + self.nbr_fea_len = nbr_fea_len + self.k = k + self.fc_fulls=nn.ModuleList([nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, + 2*self.atom_fea_len+self.nbr_fea_len) for i in range(k)]) + self.fc_fs=nn.ModuleList([nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, + self.atom_fea_len) for i in range(k)]) + self.fc_ss=nn.ModuleList([nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, + self.atom_fea_len) for i in range(k)]) + self.softmax= nn.Softmax(dim=1) + self.softmax2= nn.Softmax(dim=2) + self.softmax3= nn.Softmax(dim=2) + self.softplus1 = nn.ReLU() + self.bn1s = nn.ModuleList([nn.BatchNorm1d(2*self.atom_fea_len+self.nbr_fea_len) for i in range(k)]) + self.bn2s = nn.ModuleList([nn.BatchNorm1d(self.atom_fea_len) for i in range(k)]) + self.atom_fc = nn.Linear(self.k , 2*self.k) + self.nbr_fc = nn.Linear(self.k , 2*self.k) + self.dropout = nn.Dropout() + + def forward(self, x, edge_index, distances): + self.new_attr = [] + self.outs = [] + self.currInd = 0 + for i in range(self.k): + self.outs.append(x) + self.new_attr.append(distances) + for i in range(self.k): + aggregatedMessages = self.propagate(edge_index, x=x, distances=distances) + 
aggregatedMessages = self.bn2s[i](aggregatedMessages) + self.outs[i] = aggregatedMessages + x + out=torch.stack(self.outs, dim=2) + new_nbr=torch.stack(self.new_attr, dim=2) + + out_gated=self.atom_fc(out) + new_nbr_gated=self.nbr_fc(new_nbr) + + out_core, out_filter = out_gated.split([self.k, self.k], dim=2) + new_nbr_core, new_nbr_filter = new_nbr_gated.split([self.k, self.k], dim=2) + out_filter=self.softmax2(out_filter) + new_nbr_filter=self.softmax3(new_nbr_filter) + out = torch.sum(out_core * out_filter, dim=2) + new_nbr = torch.sum(new_nbr_core* new_nbr_filter, dim=2) + return out, new_nbr + + + + def message(self, x_i, x_j, distances): + z = torch.cat([x_i, x_j, distances], dim=-1) + total_gated_fea = self.fc_fulls[self.currInd](z) + total_gated_fea = self.bn1s[self.currInd](total_gated_fea) + nbr_filter, nbr_core, new_edge_attrs = total_gated_fea.split([self.atom_fea_len,self.atom_fea_len,self.nbr_fea_len], dim=1) + self.new_attr[self.currInd] += new_edge_attrs + self.currInd += 1 + return self.softmax(self.fc_fs[self.currInd-1](z)) * self.softplus1(self.fc_ss[self.currInd-1](z)) + + + + +@registry.register_model("CrystalGraphMulti") +class CrystalGraphConvNet(BaseModel): + """ + Create a crystal graph convolutional neural network for predicting total + material properties. + """ + def __init__(self, node_dim, edge_dim, output_dim, + dim1=128, n_conv=4, dim2=128, n_h=1,k=3, pool="global_mean_pool", + pool_order="early", act="relu", classification=False, **kwargs): + """ + Initialize CrystalGraphConvNet. + + Parameters + ---------- + + node_dim: int + Number of atom features in the input. + edge_dim: int + Number of bond features. 
+ dim1: int + Number of hidden atom features in the convolutional layers + n_conv: int + Number of convolutional layers + dim2: int + Number of hidden features after pooling + n_h: int + Number of hidden layers after pooling + """ + super(CrystalGraphConvNet, self).__init__() + self.classification = classification + self.embedding = nn.Linear(node_dim, dim1) + self.convs = nn.ModuleList([ConvLayer(atom_fea_len=dim1, + nbr_fea_len=edge_dim,k=k) + for _ in range(n_conv)]) + self.conv_to_fc = nn.Linear(dim1, dim2) + self.conv_to_fc_softplus = nn.ReLU() + self.output_softplus= nn.ReLU() + if n_h > 1: + self.fcs = nn.ModuleList([nn.Linear(dim2, dim2) + for _ in range(n_h-1)]) + self.softpluses = nn.ModuleList([nn.ReLU() + for _ in range(n_h-1)]) + if self.classification: + self.fc_out = nn.Linear(dim2, 2) + else: + self.fc_out = nn.Linear(dim2, output_dim) + if self.classification: + self.logsoftmax = nn.LogSoftmax(dim=1) + self.dropout = nn.Dropout() + + self.pool = pool + self.act = act + self.pool_order = pool_order + + def forward(self, data): + + output = {} + out = self._forward(data) + output["output"] = out + + if self.gradient == True and out.requires_grad == True: + if self.gradient_method == "conventional": + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.cell], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + #For calculation of stress, see https://github.com/mir-group/nequip/blob/main/nequip/nn/_grad_output.py + #Originally from: https://github.com/atomistic-machine-learning/schnetpack/issues/165 + elif self.gradient_method == "nequip": + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, 
data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) + + output["pos_grad"] = forces + output["cell_grad"] = stress + else: + output["pos_grad"] = None + output["cell_grad"] = None + + return output + @conditional_grad(torch.enable_grad()) + def _forward(self, data): + """ + Forward pass + + data: graph features + + Parameters + ---------- + + data: + data.x: node features + shape = (N, node_dim) + data.edge_index: list of edges + shape = (2, E) + data.edge_attr: edge attributes (distances) + shape = (E, edge_dim) + data.batch: crystal id for each node + shape = (N, ) + + Returns + ------- + + prediction: graph predictions + shape = (batch_size, ) + """ + if self.otf_edge_index == True: + data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + else: + logging.warning("Edge attributes should be re-computed for otf edge indices.") + + if self.otf_edge_index == False: + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() + + + atom_fea = data.x + edge_index = data.edge_index + distances = data.edge_attr + atom_fea = self.embedding(atom_fea) + for conv_func in self.convs: + atom_fea, distances = conv_func(atom_fea, edge_index, distances) + + if self.prediction_level == "graph": + if self.pool_order == "early": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + elif self.pool_order == "late": + crys_fea = self.conv_to_fc(getattr(F, 
self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + + elif self.prediction_level == "node": + crys_fea = getattr(torch_geometric.nn, self.pool)(atom_fea, data.batch) + crys_fea = self.conv_to_fc(getattr(F, self.act)(crys_fea)) + crys_fea = getattr(F, self.act)(crys_fea) + if hasattr(self, 'fcs'): + for fc in self.fcs: + crys_fea = getattr(F, self.act)(fc(crys_fea)) + out = self.fc_out(crys_fea) + return out + + @property + def target_attr(self): + return "y" +