Skip to content

Commit

Permalink
Debug model_training.py to correct the differences between ensemble a…
Browse files Browse the repository at this point in the history
…nd predict_labelled predictions
  • Loading branch information
kspruthviraj committed Nov 24, 2022
1 parent aee5e7a commit a812395
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions utils/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ def import_deit_models_for_testing(self, train_main, test_main):
else:
print('This model cannot be imported. Please check from the list of models')

if torch.cuda.is_available() and train_main.params.use_gpu == 'yes':
device = torch.device("cuda:0")
if torch.cuda.is_available() and test_main.params.use_gpu == 'yes':
device = torch.device("cuda")
else:
device = torch.device("cpu")

Expand All @@ -133,8 +133,10 @@ def import_deit_models_for_testing(self, train_main, test_main):
print(f"{total_trainable_params:,} training parameters.")
class_weights_tensor = torch.load(test_main.params.main_param_path + '/class_weights_tensor.pt')
self.criterion = nn.CrossEntropyLoss(class_weights_tensor)
gpu_id = 1
if torch.cuda.is_available() and train_main.params.use_gpu == 'yes':

gpu_id = 1 # hard coded, but can be changed

if torch.cuda.is_available() and test_main.params.use_gpu == 'yes':
torch.cuda.set_device(gpu_id)
self.model.cuda(gpu_id)
self.criterion = self.criterion.cuda(gpu_id)
Expand Down

0 comments on commit a812395

Please sign in to comment.