diff --git a/utils/model_training.py b/utils/model_training.py index 9995362..7ce9a2e 100644 --- a/utils/model_training.py +++ b/utils/model_training.py @@ -527,13 +527,6 @@ def train_and_save(self, train_main, data_loader): print('If you want to retrain then set "resume from saved" to "yes"') self.run_prediction(data_loader, 'finetuned') - elif train_main.params.finetune == 0: - if not os.path.exists(model_present_path0): - self.train_predict(train_main, data_loader, 0) - else: - print('If you want to retrain then set "resume from saved" to "yes"') - self.run_prediction(data_loader, 'original') - elif train_main.params.finetune == 1: if not os.path.exists(model_present_path0): self.train_predict(train_main, data_loader, 0) @@ -548,6 +541,13 @@ def train_and_save(self, train_main, data_loader): print('If you want to retrain then set "resume from saved" to "yes"') self.run_prediction(data_loader, 'tuned') + elif train_main.params.finetune == 0: + if not os.path.exists(model_present_path0): + self.train_predict(train_main, data_loader, 0) + else: + print('If you want to retrain then set "resume from saved" to "yes"') + self.run_prediction(data_loader, 'original') + elif train_main.params.resume_from_saved == 'yes': if train_main.params.finetune == 0: if not os.path.exists(model_present_path0): @@ -1029,21 +1029,21 @@ def load_model_and_run_prediction(self, train_main, test_main, data_loader): self.import_deit_models_for_testing(train_main, test_main) if test_main.params.finetuned == 0: - # self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) + self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) if test_main.params.ensemble == 0: self.run_prediction_on_unseen(test_main, data_loader, 'original') else: self.run_ensemble_prediction_on_unseen(test_main, data_loader, 'original') elif test_main.params.finetuned == 1: - # self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) + self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) if test_main.params.ensemble == 0: self.run_prediction_on_unseen(test_main, data_loader, 'tuned') else: self.run_ensemble_prediction_on_unseen(test_main, data_loader, 'tuned') elif test_main.params.finetuned == 2: - # self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) + self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) if test_main.params.ensemble == 0: self.run_prediction_on_unseen(test_main, data_loader, 'finetuned') else: @@ -1056,21 +1056,21 @@ def load_model_and_run_prediction_with_y(self, train_main, test_main, data_loade self.import_deit_models_for_testing(train_main, test_main) if test_main.params.finetuned == 0: - # self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) + self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) if test_main.params.ensemble == 0: self.run_prediction_on_unseen_with_y(test_main, data_loader, 'original') else: self.run_ensemble_prediction_on_unseen_with_y(test_main, data_loader, 'original') elif test_main.params.finetuned == 1: - # self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) + self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) if test_main.params.ensemble == 0: self.run_prediction_on_unseen_with_y(test_main, data_loader, 'tuned') else: self.run_ensemble_prediction_on_unseen_with_y(test_main, data_loader, 'tuned') elif test_main.params.finetuned == 2: - # self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) + self.initialize_model(train_main, test_main, data_loader, train_main.params.lr) if test_main.params.ensemble == 0: self.run_prediction_on_unseen_with_y(test_main, data_loader, 'finetuned') else: @@ -1367,7 +1367,7 @@ def cls_predict_on_unseen_with_y(val_loader, model, criterion, time_begin=None): images, target = images.to(device), target.to(device) targets.append(target) - output, x = model(images) + output = model(images) outputs.append(output) prob = torch.nn.functional.softmax(output, dim=1) probs.append(prob)