From 46637d3f0941e66eea3c24ba930926470506f524 Mon Sep 17 00:00:00 2001 From: knc6 Date: Wed, 24 Jan 2024 06:29:27 -0500 Subject: [PATCH] Torch deterministic. --- alignn/__init__.py | 2 +- alignn/train.py | 23 ++++++++++++++++------- setup.py | 2 +- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/alignn/__init__.py b/alignn/__init__.py index e64e962..e094167 100644 --- a/alignn/__init__.py +++ b/alignn/__init__.py @@ -1,2 +1,2 @@ """Version number.""" -__version__ = "2024.1.4" +__version__ = "2024.1.14" diff --git a/alignn/train.py b/alignn/train.py index 7633a70..1b00279 100644 --- a/alignn/train.py +++ b/alignn/train.py @@ -266,6 +266,22 @@ def train_dgl( "alignn_cgcnn": ACGCNN, "alignn_layernorm": ALIGNN_LN, } + if config.random_seed is not None: + random.seed(config.random_seed) + torch.manual_seed(config.random_seed) + np.random.seed(config.random_seed) + torch.cuda.manual_seed_all(config.random_seed) + try: + import torch_xla.core.xla_model as xm + + xm.set_rng_state(config.random_seed) + except ImportError: + pass + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PYTHONHASHSEED"] = str(config.random_seed) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = str(":4096:8") + torch.use_deterministic_algorithms(True) if model is None: net = _model.get(config.model.name)(config.model) else: @@ -277,7 +293,6 @@ def train_dgl( # group parameters to skip weight decay for bias and batchnorm params = group_decay(net) optimizer = setup_optimizer(params, config) - if config.scheduler == "none": # always return multiplier of 1 (i.e. do nothing) scheduler = torch.optim.lr_scheduler.LambdaLR( @@ -302,12 +317,6 @@ def train_dgl( ) if config.model.name == "alignn_atomwise": - if config.random_seed is not None: - random.seed(config.random_seed) - np.random.seed(config.random_seed) - torch.manual_seed(config.random_seed) - torch.cuda.manual_seed_all(config.random_seed) - torch.backends.cudnn.deterministic = True def get_batch_errors(dat=[]): """Get errors for samples.""" diff --git a/setup.py b/setup.py index 50ca78e..a501c58 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setuptools.setup( name="alignn", - version="2024.1.4", + version="2024.1.14", author="Kamal Choudhary, Brian DeCost", author_email="kamal.choudhary@nist.gov", description="alignn",