From a8123951c5644c36bf427d2af310a996c476b992 Mon Sep 17 00:00:00 2001 From: Sreenath P Kyathanahally Date: Thu, 24 Nov 2022 11:03:28 +0100 Subject: [PATCH] Debug model_training.py to correct the differences between ensemble and predict_labelled predictions --- utils/model_training.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/utils/model_training.py b/utils/model_training.py index 7ce9a2e..b96262d 100644 --- a/utils/model_training.py +++ b/utils/model_training.py @@ -117,8 +117,8 @@ def import_deit_models_for_testing(self, train_main, test_main): else: print('This model cannot be imported. Please check from the list of models') - if torch.cuda.is_available() and train_main.params.use_gpu == 'yes': - device = torch.device("cuda:0") + if torch.cuda.is_available() and test_main.params.use_gpu == 'yes': + device = torch.device("cuda") else: device = torch.device("cpu") @@ -133,8 +133,10 @@ def import_deit_models_for_testing(self, train_main, test_main): print(f"{total_trainable_params:,} training parameters.") class_weights_tensor = torch.load(test_main.params.main_param_path + '/class_weights_tensor.pt') self.criterion = nn.CrossEntropyLoss(class_weights_tensor) - gpu_id = 1 - if torch.cuda.is_available() and train_main.params.use_gpu == 'yes': + + gpu_id = 1 # hard coded, but can be changed + + if torch.cuda.is_available() and test_main.params.use_gpu == 'yes': torch.cuda.set_device(gpu_id) self.model.cuda(gpu_id) self.criterion = self.criterion.cuda(gpu_id)