Skip to content

Commit

Permalink
remove optimizer from test()
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeBatch committed Mar 1, 2023
1 parent 217ac7f commit 59a62b5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train_tcga.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def dropout_patches(feats, p):
sampled_feats = np.concatenate((sampled_feats, pad_feats), axis=0)
return sampled_feats

def test(test_df, milnet, criterion, optimizer, args):
def test(test_df, milnet, criterion, args):
milnet.eval()
csvs = shuffle(test_df).reset_index(drop=True)
total_loss = 0
Expand Down Expand Up @@ -188,7 +188,7 @@ def main():
train_path = shuffle(train_path).reset_index(drop=True)
test_path = shuffle(test_path).reset_index(drop=True)
train_loss_bag = train(train_path, milnet, criterion, optimizer, args) # iterate all bags
test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, optimizer, args)
test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, args)
if args.dataset.startswith('TCGA-lung'):
print('\r Epoch [%d/%d] train loss: %.4f test loss: %.4f, average score: %.4f, auc_LUAD: %.4f, auc_LUSC: %.4f' %
(epoch, args.num_epochs, train_loss_bag, test_loss_bag, avg_score, aucs[0], aucs[1]))
Expand Down

0 comments on commit 59a62b5

Please sign in to comment.