Skip to content

Commit

Permalink
Debug model_training.py to correct the differences between ensemble a…
Browse files Browse the repository at this point in the history
…nd predict_labelled predictions
  • Loading branch information
kspruthviraj committed Nov 24, 2022
1 parent 9917b15 commit aee5e7a
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions utils/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,6 @@ def train_and_save(self, train_main, data_loader):
print('If you want to retrain then set "resume from saved" to "yes"')
self.run_prediction(data_loader, 'finetuned')

elif train_main.params.finetune == 0:
if not os.path.exists(model_present_path0):
self.train_predict(train_main, data_loader, 0)
else:
print('If you want to retrain then set "resume from saved" to "yes"')
self.run_prediction(data_loader, 'original')

elif train_main.params.finetune == 1:
if not os.path.exists(model_present_path0):
self.train_predict(train_main, data_loader, 0)
Expand All @@ -548,6 +541,13 @@ def train_and_save(self, train_main, data_loader):
print('If you want to retrain then set "resume from saved" to "yes"')
self.run_prediction(data_loader, 'tuned')

elif train_main.params.finetune == 0:
if not os.path.exists(model_present_path0):
self.train_predict(train_main, data_loader, 0)
else:
print('If you want to retrain then set "resume from saved" to "yes"')
self.run_prediction(data_loader, 'original')

elif train_main.params.resume_from_saved == 'yes':
if train_main.params.finetune == 0:
if not os.path.exists(model_present_path0):
Expand Down Expand Up @@ -1029,21 +1029,21 @@ def load_model_and_run_prediction(self, train_main, test_main, data_loader):
self.import_deit_models_for_testing(train_main, test_main)

if test_main.params.finetuned == 0:
# self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
if test_main.params.ensemble == 0:
self.run_prediction_on_unseen(test_main, data_loader, 'original')
else:
self.run_ensemble_prediction_on_unseen(test_main, data_loader, 'original')

elif test_main.params.finetuned == 1:
# self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
if test_main.params.ensemble == 0:
self.run_prediction_on_unseen(test_main, data_loader, 'tuned')
else:
self.run_ensemble_prediction_on_unseen(test_main, data_loader, 'tuned')

elif test_main.params.finetuned == 2:
# self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
if test_main.params.ensemble == 0:
self.run_prediction_on_unseen(test_main, data_loader, 'finetuned')
else:
Expand All @@ -1056,21 +1056,21 @@ def load_model_and_run_prediction_with_y(self, train_main, test_main, data_loade
self.import_deit_models_for_testing(train_main, test_main)

if test_main.params.finetuned == 0:
# self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
if test_main.params.ensemble == 0:
self.run_prediction_on_unseen_with_y(test_main, data_loader, 'original')
else:
self.run_ensemble_prediction_on_unseen_with_y(test_main, data_loader, 'original')

elif test_main.params.finetuned == 1:
# self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
if test_main.params.ensemble == 0:
self.run_prediction_on_unseen_with_y(test_main, data_loader, 'tuned')
else:
self.run_ensemble_prediction_on_unseen_with_y(test_main, data_loader, 'tuned')

elif test_main.params.finetuned == 2:
# self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
self.initialize_model(train_main, test_main, data_loader, train_main.params.lr)
if test_main.params.ensemble == 0:
self.run_prediction_on_unseen_with_y(test_main, data_loader, 'finetuned')
else:
Expand Down Expand Up @@ -1367,7 +1367,7 @@ def cls_predict_on_unseen_with_y(val_loader, model, criterion, time_begin=None):
images, target = images.to(device), target.to(device)
targets.append(target)

output, x = model(images)
output = model(images)
outputs.append(output)
prob = torch.nn.functional.softmax(output, dim=1)
probs.append(prob)
Expand Down

0 comments on commit aee5e7a

Please sign in to comment.