Skip to content

Commit

Permalink
Merge pull request #139 from usnistgov/develop
Browse files Browse the repository at this point in the history
Use torch deterministic algorithm.
  • Loading branch information
knc6 authored Jan 24, 2024
2 parents 339db91 + 46637d3 commit dcd73d1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion alignn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Version number."""
__version__ = "2024.1.4"
__version__ = "2024.1.14"
23 changes: 16 additions & 7 deletions alignn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,22 @@ def train_dgl(
"alignn_cgcnn": ACGCNN,
"alignn_layernorm": ALIGNN_LN,
}
if config.random_seed is not None:
random.seed(config.random_seed)
torch.manual_seed(config.random_seed)
np.random.seed(config.random_seed)
torch.cuda.manual_seed_all(config.random_seed)
try:
import torch_xla.core.xla_model as xm

xm.set_rng_state(config.random_seed)
except ImportError:
pass
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(config.random_seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = str(":4096:8")
torch.use_deterministic_algorithms(True)
if model is None:
net = _model.get(config.model.name)(config.model)
else:
Expand All @@ -277,7 +293,6 @@ def train_dgl(
# group parameters to skip weight decay for bias and batchnorm
params = group_decay(net)
optimizer = setup_optimizer(params, config)

if config.scheduler == "none":
# always return multiplier of 1 (i.e. do nothing)
scheduler = torch.optim.lr_scheduler.LambdaLR(
Expand All @@ -302,12 +317,6 @@ def train_dgl(
)

if config.model.name == "alignn_atomwise":
if config.random_seed is not None:
random.seed(config.random_seed)
np.random.seed(config.random_seed)
torch.manual_seed(config.random_seed)
torch.cuda.manual_seed_all(config.random_seed)
torch.backends.cudnn.deterministic = True

def get_batch_errors(dat=[]):
"""Get errors for samples."""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setuptools.setup(
name="alignn",
version="2024.1.4",
version="2024.1.14",
author="Kamal Choudhary, Brian DeCost",
author_email="[email protected]",
description="alignn",
Expand Down

0 comments on commit dcd73d1

Please sign in to comment.