Skip to content

Commit

Permalink
ensemble_model_prediction correction
Browse files Browse the repository at this point in the history
  • Loading branch information
kspruthviraj committed Jul 13, 2022
1 parent 0fbe642 commit 8efc98d
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions utils/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,13 @@ def run_prediction_on_unseen(self, test_main, data_loader, name):
output_label = np.array([classes[output_max[i]] for i in range(len(output_max))], dtype=object)

Pred_PredLabel_Prob = [output_max, output_label, prob]
with open(test_main.params.test_outpath + '/Pred_PredLabel_Prob' + name + '.pickle', 'wb') as cw:
with open(test_main.params.test_outpath + '/Single_model_Pred_PredLabel_Prob_' + name + '.pickle', 'wb') as cw:
pickle.dump(Pred_PredLabel_Prob, cw)

output_label = output_label.tolist()

To_write = [i + '------------------' + j + '\n' for i, j in zip(im_names[0], output_label)]
np.savetxt(test_main.params.test_outpath + '/Plankiformer_predictions.txt', To_write, fmt='%s')
np.savetxt(test_main.params.test_outpath + '/Single_model_Plankiformer_predictions.txt', To_write, fmt='%s')

def run_ensemble_prediction_on_unseen(self, test_main, data_loader, name):
classes = np.load(test_main.params.main_param_path + '/classes.npy')
Expand Down Expand Up @@ -390,27 +390,30 @@ def run_ensemble_prediction_on_unseen(self, test_main, data_loader, name):
Ens_DEIT_prob_max = []
Ens_DEIT_label = []
Ens_DEIT = []
name2 = []

if test_main.params.ensemble == 1:
Ens_DEIT = sum(Ensemble_prob) / len(Ensemble_prob)
Ens_DEIT_prob_max = Ens_DEIT.argmax(axis=1) # The class that the classifier would bet on
Ens_DEIT_label = np.array([classes[Ens_DEIT_prob_max[i]] for i in range(len(Ens_DEIT_prob_max))],
dtype=object)
name2 = 'arth_mean_'

elif test_main.params.ensemble == 1:
elif test_main.params.ensemble == 2:
Ens_DEIT = gmean(Ensemble_prob)
Ens_DEIT_prob_max = Ens_DEIT.argmax(axis=1) # The class that the classifier would bet on
Ens_DEIT_label = np.array([classes[Ens_DEIT_prob_max[i]] for i in range(len(Ens_DEIT_prob_max))],
dtype=object)
name2 = 'geo_mean_'

Pred_PredLabel_Prob = [Ens_DEIT_prob_max, Ens_DEIT_label, Ens_DEIT]
with open(test_main.params.test_outpath + '/Pred_PredLabel_Prob' + name + '.pickle', 'wb') as cw:
with open(test_main.params.test_outpath + '/Ensemble_models_Pred_PredLabel_Prob_' + name2 + name + '.pickle', 'wb') as cw:
pickle.dump(Pred_PredLabel_Prob, cw)

Ens_DEIT_label = Ens_DEIT_label.tolist()

To_write = [i + '------------------' + j + '\n' for i, j in zip(im_names[0], Ens_DEIT_label)]
np.savetxt(test_main.params.test_outpath + '/Plankiformer_predictions_ens.txt', To_write, fmt='%s')
np.savetxt(test_main.params.test_outpath + '/Ensemble_models_Plankiformer_predictions' + name2 + name + '.txt', To_write, fmt='%s')

def initialize_model(self, train_main, test_main, data_loader, lr):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Expand Down

0 comments on commit 8efc98d

Please sign in to comment.