diff --git a/compute_feats.py b/compute_feats.py index 3c69cc2..7f5356a 100644 --- a/compute_feats.py +++ b/compute_feats.py @@ -81,7 +81,7 @@ def compute_feats(args, bags_list, i_classifier, save_path=None, magnification=' os.makedirs(os.path.join(save_path, bags_list[i].split(os.path.sep)[-2]), exist_ok=True) df.to_csv(os.path.join(save_path, bags_list[i].split(os.path.sep)[-2], bags_list[i].split(os.path.sep)[-1]+'.csv'), index=False, float_format='%.4f') -def compute_tree_feats(args, bags_list, embedder_low, embedder_high, save_path=None, fusion='fusion'): +def compute_tree_feats(args, bags_list, embedder_low, embedder_high, save_path=None): embedder_low.eval() embedder_high.eval() num_bags = len(bags_list) @@ -107,10 +107,14 @@ def compute_tree_feats(args, bags_list, embedder_low, embedder_high, save_path=N img = Image.open(high_patch) img = VF.to_tensor(img).float().cuda() feats, classes = embedder_high(img[None, :]) - if fusion == 'fusion': + + if args.tree_fusion == 'fusion': feats = feats.cpu().numpy()+0.25*feats_list[idx] - if fusion == 'cat': + elif args.tree_fusion == 'cat': feats = np.concatenate((feats.cpu().numpy(), feats_list[idx][None, :]), axis=-1) + else: + raise NotImplementedError(f"{args.tree_fusion} is not an excepted option for --tree_fusion. This argument accepts 2 options: 'fusion' and 'cat'.") + feats_tree_list.extend(feats) sys.stdout.write('\r Computed: {}/{} -- {}/{}'.format(i+1, num_bags, idx+1, len(low_patches))) if len(feats_tree_list) == 0: @@ -133,6 +137,7 @@ def main(): parser.add_argument('--weights', default=None, type=str, help='Folder of the pretrained weights, simclr/runs/*') parser.add_argument('--weights_high', default=None, type=str, help='Folder of the pretrained weights of high magnification, FOLDER < `simclr/runs/[FOLDER]`') parser.add_argument('--weights_low', default=None, type=str, help='Folder of the pretrained weights of low magnification, FOLDER <`simclr/runs/[FOLDER]`') + parser.add_argument('--tree_fusion', default='cat', type=str, help='Fusion method for high and low mag features in a tree method [cat|fusion]') parser.add_argument('--dataset', default='TCGA-lung-single', type=str, help='Dataset folder name [TCGA-lung-single]') args = parser.parse_args() gpu_ids = tuple(args.gpu_index) @@ -238,7 +243,7 @@ def main(): bags_list = glob.glob(bags_path) if args.magnification == 'tree': - compute_tree_feats(args, bags_list, i_classifier_l, i_classifier_h, feats_path, 'cat') + compute_tree_feats(args, bags_list, i_classifier_l, i_classifier_h, feats_path) else: compute_feats(args, bags_list, i_classifier, feats_path, args.magnification) n_classes = glob.glob(os.path.join('datasets', args.dataset, '*'+os.path.sep)) diff --git a/deepzoom_tiler.py b/deepzoom_tiler.py index 0f90d3f..f55fd76 100644 --- a/deepzoom_tiler.py +++ b/deepzoom_tiler.py @@ -254,7 +254,7 @@ def nested_patches(img_slide, out_base, level=(0,), ext='jpeg'): parser.add_argument('-o', '--objective', type=float, default=20, help='The default objective power if metadata does not present [20]') parser.add_argument('-t', '--background_t', type=int, default=15, help='Threshold for filtering background [15]') args = parser.parse_args() - levels = tuple(args.magnifications) + levels = tuple(sorted(args.magnifications)) assert len(levels)<=2, 'Only 1 or 2 magnifications are supported!' path_base = os.path.join('WSI', args.dataset) if len(levels) == 2: diff --git a/train_tcga.py b/train_tcga.py index 4c3b209..45eea73 100644 --- a/train_tcga.py +++ b/train_tcga.py @@ -62,7 +62,7 @@ def dropout_patches(feats, p): sampled_feats = np.concatenate((sampled_feats, pad_feats), axis=0) return sampled_feats -def test(test_df, milnet, criterion, optimizer, args): +def test(test_df, milnet, criterion, args): milnet.eval() csvs = shuffle(test_df).reset_index(drop=True) total_loss = 0 @@ -188,8 +188,8 @@ def main(): train_path = shuffle(train_path).reset_index(drop=True) test_path = shuffle(test_path).reset_index(drop=True) train_loss_bag = train(train_path, milnet, criterion, optimizer, args) # iterate all bags - test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, optimizer, args) - if args.dataset=='TCGA-lung': + test_loss_bag, avg_score, aucs, thresholds_optimal = test(test_path, milnet, criterion, args) + if args.dataset.startswith('TCGA-lung'): print('\r Epoch [%d/%d] train loss: %.4f test loss: %.4f, average score: %.4f, auc_LUAD: %.4f, auc_LUSC: %.4f' % (epoch, args.num_epochs, train_loss_bag, test_loss_bag, avg_score, aucs[0], aucs[1])) else: @@ -201,7 +201,7 @@ def main(): best_score = current_score save_name = os.path.join(save_path, str(run+1)+'.pth') torch.save(milnet.state_dict(), save_name) - if args.dataset=='TCGA-lung': + if args.dataset.startswith('TCGA-lung'): print('Best model saved at: ' + save_name + ' Best thresholds: LUAD %.4f, LUSC %.4f' % (thresholds_optimal[0], thresholds_optimal[1])) else: print('Best model saved at: ' + save_name)