Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
f3f89ba
Also update data.edge_vecs
itsgt Oct 18, 2023
f12ad99
Merge pull request #50 from Fung-Lab/otf_update_edge_vecs
vxfung Oct 18, 2023
d05b713
Added crystal graph
Nov 12, 2023
7f2b753
Fixing performance and config
Nov 13, 2023
a915273
Increasing num neighbors
Nov 13, 2023
76d988f
Remove dependence of calculator on trainer
Nov 20, 2023
b14714c
Update optimization
Nov 25, 2023
7cc2bc0
model_ensemble
orivera2280 Jan 2, 2024
cf088bb
model ensemble
orivera2280 Jan 2, 2024
1b4c0fe
model ensemble
orivera2280 Jan 2, 2024
5f102b3
model ensemble
orivera2280 Jan 2, 2024
db55682
updating config.yml
orivera2280 Jan 2, 2024
1c3a9ad
model ensemble fixes
orivera2280 Jan 2, 2024
a3e4758
model ensemble fixes 2
orivera2280 Jan 2, 2024
79ddc16
fixing config
orivera2280 Jan 2, 2024
de471a2
fixing indentation
orivera2280 Jan 2, 2024
db0a751
fixing indentation
orivera2280 Jan 2, 2024
64ad1e8
fixing bug in _forward
orivera2280 Jan 3, 2024
02ef025
fixing bug in property trainer _forward
orivera2280 Jan 3, 2024
965b07a
fixing problem with save_model in base_trainer
orivera2280 Jan 3, 2024
ac5a587
fixing problem with save_model in base_trainer
orivera2280 Jan 3, 2024
e991348
std key for prediction dictionary
orivera2280 Jan 5, 2024
04249a4
std key for prediction dictionary
orivera2280 Jan 5, 2024
48d7940
fixing bugs with std and predict method
orivera2280 Jan 5, 2024
247feb3
fixing type error
orivera2280 Jan 5, 2024
0503ffb
various bug fixes and improving readability
orivera2280 Jan 7, 2024
08894ab
Merge pull request #55 from orivera2280/main
vxfung Jan 7, 2024
2e3b454
Merge branch 'main' into crystalGraph
vxfung Jan 7, 2024
1743094
Merge pull request #53 from Fung-Lab/crystalGraph
vxfung Jan 7, 2024
5f089df
config change and result save bugfix
vxfung Jan 7, 2024
133fc40
result save fix2
vxfung Jan 8, 2024
a10cbe6
Fixed merging conficts
qzheng75 Jan 9, 2024
394f9b1
Fixed merging conficts
qzheng75 Jan 9, 2024
d77e132
Support ensemble model calculation.
qzheng75 Jan 9, 2024
1469e2a
Remove predict_by_calculator
qzheng75 Jan 9, 2024
c3d9b2b
Remove extra files
qzheng75 Jan 9, 2024
62d58f9
Instantiate natoms with device
itsgt Jan 17, 2024
0182638
Merge pull request #59 from Fung-Lab/ocp_device_fix
vxfung Jan 17, 2024
ac75463
Merge branch 'main' into Calculator_test
vxfung Jan 17, 2024
6e60d2b
Merge pull request #57 from Fung-Lab/Calculator_test
vxfung Jan 17, 2024
126d4c7
Update ase_utils.py
vxfung Jan 22, 2024
6f218f9
Merge pull request #60 from Fung-Lab/vxfung-patch-3
qzheng75 Jan 22, 2024
cc25a20
Use torch_geometric.utils for scatter
itsgt Jan 22, 2024
26fc3af
Merge pull request #61 from Fung-Lab/torchmd_fix_scatter
vxfung Jan 22, 2024
7124294
Allow None for dataset and data_loader
itsgt Jan 22, 2024
4266f2c
Allow _load_model to have dataset = None
itsgt Jan 22, 2024
1b86060
Allow sampler to be none too
itsgt Jan 22, 2024
470c825
Update config.yml
vxfung Jan 23, 2024
2bd2281
Fix conditional
itsgt Jan 24, 2024
988c1c4
Fix another conditional
itsgt Jan 24, 2024
22b3580
'| None' appears to be unnecessary and breaking
itsgt Jan 24, 2024
8c5e081
Use dataset unless it is None
itsgt Jan 24, 2024
f5faffd
Fix typo
itsgt Jan 24, 2024
24d6970
Another priority of dataset over config
itsgt Jan 24, 2024
71538d0
Merge branch 'main' into no_data_option
itsgt Jan 25, 2024
e992766
Keep original node_dim determination
itsgt Jan 25, 2024
3cb3695
Merge pull request #62 from Fung-Lab/no_data_option
vxfung Jan 25, 2024
91e7259
Update test_predict.yml
vxfung Jan 26, 2024
b71996c
Update test_training.yml
vxfung Jan 26, 2024
665b18e
Update base_trainer.py
vxfung Jan 26, 2024
a650461
Update utils.py
vxfung Jan 26, 2024
d07afb0
Merge pull request #65 from Fung-Lab/vxfung-patch-4
vxfung Jan 26, 2024
0b75fa1
Update helpers.py
vxfung Jan 26, 2024
a801342
Merge pull request #67 from Fung-Lab/vxfung-patch-1
vxfung Jan 26, 2024
505b73f
Update base_trainer.py
vxfung Feb 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions configs/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ model:
otf_edge_attr: False
# Compute node attributes on the fly in the model forward
otf_node_attr: False
# 1 indicates normal behavior, larger numbers indicate the number of models to be used
model_ensemble: 1
# compute gradients w.r.t to positions and cell, requires otf_edge_attr=True
gradient: False

optim:
max_epochs: 40
max_epochs: 200
max_checkpoint_epochs: 0
lr: 0.001
# Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
Expand Down Expand Up @@ -130,4 +132,4 @@ dataset:
# Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data)
train_ratio: 0.8
val_ratio: 0.05
test_ratio: 0.15
test_ratio: 0.15
88 changes: 6 additions & 82 deletions configs/config_calculator.yml
Original file line number Diff line number Diff line change
@@ -1,26 +1,9 @@
trainer: matdeeplearn.trainers.PropertyTrainer

task:
run_mode: train
identifier: my_train_job
parallel: False
# If seed is not set, then it will be random every time
seed: 12345678
# Defaults to run directory if not specified
save_dir:
# continue from a previous job
continue_job: False
# specify if the training state is loaded: epochs, learning rate, etc
load_training_state: False
# Path to the checkpoint.pt file. The model used in the calculator will load parameters from this file.
checkpoint_path: results/2023-09-20-16-22-38-738-my_train_job/checkpoint/best_checkpoint.pt
# E.g. ["train", "val", "test"]
write_output: [train, val, test]
# Specify if labels are provided for the predict task
# labels: True
use_amp: True
checkpoint_path: ./checkpoints/cgcnn_checkpoint.pt

model:
# Model used by the calculator
name: CGCNN
# model attributes
dim1: 100
Expand All @@ -39,62 +22,12 @@ model:
# Compute edge attributes on the fly in the model forward
otf_edge_attr: True
# Compute node attributes on the fly in the model forward
otf_node_attr: True
otf_node_attr: False
model_ensemble: 1
# compute gradients w.r.t to positions and cell, requires otf_edge_attr=True
gradient: True

optim:
max_epochs: 40
max_checkpoint_epochs: 0
lr: 0.002
# Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
loss:
loss_type: TorchLossWrapper
loss_args: {loss_fn: l1_loss}
# gradient clipping value
clip_grad_norm: 10
batch_size: 100
optimizer:
optimizer_type: AdamW
optimizer_args: {}
scheduler:
scheduler_type: ReduceLROnPlateau
scheduler_args: {mode: min, factor: 0.8, patience: 10, min_lr: 0.00001, threshold: 0.0002}
#Training print out frequency (print per n number of epochs)
verbosity: 5
# tqdm progress bar per batch in the epoch
batch_tqdm: False

dataset:
name: test_data
# Whether the data has already been processed and a data.pt file is present from a previous run
processed: False
# Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path}
src: data/force_data/data.json
# Path to target file within data_path - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path} or left blank when the dataset is a single json file
# Example: target_path: "data/raw_graph_scalar/targets.csv"
target_path:
# Path to save processed data.pt file
pt_path: data/force_data/
# Either "node" or "graph" level
prediction_level: graph

transforms:
- name: GetY
args:
# index specifies the index of a target vector to predict, which is useful when there are multiple property labels for a single dataset
# For example, an index: 0 (default) will use the first entry in the target vector
# if all values are to be predicted simultaneously, then specify index: -1
index: -1
otf_transform: True # Optional parameter, default is True
# Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html)
data_format: json
# specify if additional attributes to be loaded into the dataset from the .json file; e.g. additional_attributes: [forces, stress]
additional_attributes:
# Print out processing info
verbose: True
# Index of target column in targets.csv
# graph specific settings
dataset:
preprocess_params:
# one of mdl (minimum image convention), ocp (all neighbors included)
edge_calc_method: ocp
Expand All @@ -118,13 +51,4 @@ dataset:
self_loop: True
# Method of obtaining atom dictionary: available: (onehot)
node_representation: onehot
all_neighbors: True

# Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html
num_workers: 0
# Where the dataset is loaded; either "cpu" or "cuda"
dataset_device: cpu
# Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data)
train_ratio: 0.9
val_ratio: 0.05
test_ratio: 0.05
all_neighbors: True
134 changes: 134 additions & 0 deletions configs/crystalGraphConfig.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
trainer: property

task:
#run_mode: train
identifier: my_train_job
parallel: False
# If seed is not set, then it will be random every time
seed: 12345678
# Defaults to run directory if not specified
save_dir:
# continue from a previous job
continue_job: False
# specify if the training state is loaded: epochs, learning rate, etc
load_training_state: False
# Path to the checkpoint.pt file
checkpoint_path:
# Whether to write predictions to csv file. E.g. ["train", "val", "test"]
write_output: [train, val, test]
# Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time
output_frequency: 0
# Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint
model_save_frequency: 0
# Specify if labels are provided for the predict task
# labels: True
# Use amp mixed precision
use_amp: True

model:
name: CrystalGraph
# model attributes
dim1: 128
dim2: 128
n_conv: 4
n_h: 1
k: 3
pool: "global_mean_pool"
pool_order: "early"
batch_norm: True
batch_track_stats: True
act: "relu"
dropout_rate: 0.0
classification: False
# Compute edge indices on the fly in the model forward
otf_edge_index: False
# Compute edge attributes on the fly in the model forward
otf_edge_attr: False
# Compute node attributes on the fly in the model forward
otf_node_attr: False
# compute gradients w.r.t to positions and cell, requires otf_edge_attr=True
gradient: False

optim:
max_epochs: 300
max_checkpoint_epochs: 0
lr: 0.001
# Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
loss:
loss_type: TorchLossWrapper
loss_args: {loss_fn: l1_loss}
# gradient clipping value
clip_grad_norm: 10
batch_size: 128
optimizer:
optimizer_type: AdamW
optimizer_args: {}
scheduler:
scheduler_type: ReduceLROnPlateau
scheduler_args: {mode: min, factor: 0.8, patience: 10, min_lr: 0.00001, threshold: 0.0002}
#Training print out frequency (print per n number of epochs)
verbosity: 5
# tqdm progress bar per batch in the epoch
batch_tqdm: False

dataset:
name: test_data
# Whether the data has already been processed and a data.pt file is present from a previous run
processed: False
# Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path}
src: data/test_data/data_graph_scalar.json
# Path to target file within data_path - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path} or left blank when the dataset is a single json file
# Example: target_path: "data/raw_graph_scalar/targets.csv"
target_path:
# Path to save processed data.pt file
pt_path: data/
# Either "node" or "graph" level
prediction_level: graph

transforms:
- name: GetY
args:
# index specifies the index of a target vector to predict, which is useful when there are multiple property labels for a single dataset
# For example, an index: 0 (default) will use the first entry in the target vector
# if all values are to be predicted simultaneously, then specify index: -1
index: 0
otf_transform: True # Optional parameter, default is True
# Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html)
data_format: json
# specify if additional attributes to be loaded into the dataset from the .json file; e.g. additional_attributes: [forces, stress]
additional_attributes:
# Print out processing info
verbose: True
# Index of target column in targets.csv
# graph specific settings
preprocess_params:
# one of mdl (minimum image convention), ocp (all neighbors included)
edge_calc_method: ocp
# determine if edges are computed, if false, then they need to be computed on the fly
preprocess_edges: True
# determine if edge attributes are computed during processing, if false, then they need to be computed on the fly
preprocess_edge_features: True
# determine if node attributes are computed during processing, if false, then they need to be computed on the fly
preprocess_node_features: True
# distance cutoff to determine if two atoms are connected by an edge
cutoff_radius : 8.0
# maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors)
n_neighbors : 250
# number of pbc offsets to consider when determining neighbors (usually not changed)
num_offsets: 2
# dimension of node attributes
node_dim : 100
# dimension of edge attributes
edge_dim : 64
# whether or not to add self-loops
self_loop: True
# Method of obtaining atom dictionary: available: (onehot)
node_representation: onehot
# Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html
num_workers: 0
# Where the dataset is loaded; either "cpu" or "cuda"
dataset_device: cpu
# Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data)
train_ratio: 0.8
val_ratio: 0.1
test_ratio: 0.1
Binary file added data/data.pt
Binary file not shown.
84 changes: 84 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
import pprint
import os
import sys
import shutil
from datetime import datetime
from torch import distributed as dist
from matdeeplearn.common.config.build_config import build_config
from matdeeplearn.common.config.flags import flags
from matdeeplearn.common.trainer_context import new_trainer_context
from matdeeplearn.preprocessor.processor import process_data

# import submitit

# from matdeeplearn.common.utils import setup_logging


class Runner:  # submitit.helpers.Checkpointable):
    """Build a trainer context from a config and execute the configured task.

    The instance keeps references to the resolved config, task, and trainer so
    that ``checkpoint`` can persist state if the job is preempted.
    """

    def __init__(self):
        # Filled in by __call__ with the fully-resolved configuration.
        self.config = None

    def __call__(self, config):

        # NOTE: `args` is the module-level namespace parsed in the __main__
        # block below — this callable is only usable from this script.
        with new_trainer_context(args=args, config=config) as ctx:
            self.config = ctx.config
            self.task = ctx.task
            self.trainer = ctx.trainer

            self.task.setup(self.trainer)

            # Log the fully-resolved settings for this job.
            logging.debug("Settings: ")
            logging.debug(pprint.pformat(self.config))

            self.task.run()

            # Relocate the run log into this run's results directory.
            log_file = 'log_' + config["task"]["log_id"] + '.txt'
            destination = os.path.join(
                self.trainer.save_dir, "results", self.trainer.timestamp_id, "log.txt"
            )
            shutil.move(log_file, destination)

    def checkpoint(self, *args, **kwargs):
        """Save trainer state so a preempted job can be resumed."""
        # new_runner = Runner()
        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = self.trainer.timestamp_id
        if self.trainer.logger is not None:
            self.trainer.logger.mark_preempting()
        # return submitit.helpers.DelayedSubmission(new_runner, self.config)


if __name__ == "__main__":

    # setup_logging()
    # Build a millisecond-resolution id for this run. This must be computed
    # unconditionally: it is stored in the config below, and previously it was
    # only defined inside the rank-0 branch, raising NameError on distributed
    # workers where LOCAL_RANK != 0.
    timestamp = datetime.now().timestamp()
    timestamp_id = datetime.fromtimestamp(timestamp).strftime(
        "%Y-%m-%d-%H-%M-%S-%f"
    )[:-3]

    # Only rank 0 (or a non-distributed run, where LOCAL_RANK is unset)
    # configures logging, so workers do not duplicate log output.
    local_rank = os.environ.get('LOCAL_RANK', None)
    if local_rank is None or int(local_rank) == 0:
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.DEBUG)

        # Mirror all output to a per-run file and to stdout.
        fh = logging.FileHandler(f'log_{timestamp_id}.txt', 'w+')
        fh.setLevel(logging.DEBUG)
        root_logger.addHandler(fh)

        sh = logging.StreamHandler(sys.stdout)
        sh.setLevel(logging.DEBUG)
        root_logger.addHandler(sh)

    parser = flags.get_parser()
    args, override_args = parser.parse_known_args()
    config = build_config(args, override_args)
    # Record the log id so Runner.__call__ can move the log file into the
    # run's results directory when the job finishes.
    config["task"]["log_id"] = timestamp_id

    # Preprocess the dataset unless a processed data.pt already exists.
    if not config["dataset"]["processed"]:
        process_data(config["dataset"])

    if args.submit:  # Run on cluster
        # TODO: add setup to submit to cluster
        pass

    else:  # Run locally
        Runner()(config)

Loading