From 6125adc027bbc14c3b208ea6ec8441de4d4fe3fb Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:33:54 +0000 Subject: [PATCH 1/6] added comments, and no need to patch the images in offline mode --- GANDLF/cli/preprocess_and_save.py | 43 ++++++++++++++----------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/GANDLF/cli/preprocess_and_save.py b/GANDLF/cli/preprocess_and_save.py index 1da202ae7..c35f307fe 100644 --- a/GANDLF/cli/preprocess_and_save.py +++ b/GANDLF/cli/preprocess_and_save.py @@ -8,6 +8,7 @@ parseTrainingCSV, populate_header_in_parameters, get_dataframe, + get_correct_padding_size, ) from GANDLF.parseConfig import parseConfig from GANDLF.data.ImagesFromDataFrame import ImagesFromDataFrame @@ -18,8 +19,13 @@ def preprocess_and_save( - data_csv, config_file, output_dir, label_pad_mode="constant", applyaugs=False -): + data_csv: str, + config_file: str, + output_dir: str, + label_pad_mode: str = "constant", + applyaugs: bool = False, + apply_zero_crop: bool = False, +) -> None: """ This function performs preprocessing based on parameters provided and saves the output. @@ -55,7 +61,11 @@ def preprocess_and_save( parameters = populate_header_in_parameters(parameters, headers) data_for_processing = ImagesFromDataFrame( - dataframe, parameters, train=applyaugs, apply_zero_crop=True, loader_type="full" + dataframe, + parameters, + train=applyaugs, + apply_zero_crop=apply_zero_crop, + loader_type="full", ) dataloader_for_processing = DataLoader( @@ -117,34 +127,19 @@ def preprocess_and_save( subject_dict_to_write = torchio.Subject(subject_process) # apply a different padding mode to image and label (so that label information is not duplicated) - if (parameters["patch_sampler"] == "label") or ( - isinstance(parameters["patch_sampler"], dict) - ): + if parameters["patch_sampler"]["type"] == "label": # get the padding size from the patch_size - psize_pad = list( - np.asarray(np.ceil(np.divide(parameters["patch_size"], 2)), dtype=int) + psize_pad = get_correct_padding_size( + parameters["patch_size"], parameters["model"]["dimension"] ) # initialize the padder for images padder = torchio.transforms.Pad( - psize_pad, padding_mode="symmetric", include=keys_with_images + psize_pad, + padding_mode=label_pad_mode, + include=keys_with_images + ["label"], ) subject_dict_to_write = padder(subject_dict_to_write) - if parameters["headers"]["labelHeader"] is not None: - # initialize the padder for label - padder_label = torchio.transforms.Pad( - psize_pad, padding_mode=label_pad_mode, include="label" - ) - subject_dict_to_write = padder_label(subject_dict_to_write) - - sampler = torchio.data.LabelSampler(parameters["patch_size"]) - generator = sampler(subject_dict_to_write, num_patches=1) - for patch in generator: - for channel in parameters["headers"]["channelHeaders"]: - subject_dict_to_write[str(channel)] = patch[str(channel)] - - subject_dict_to_write["label"] = patch["label"] - # write new images common_ext = get_filename_extension_sanitized(subject["path_to_metadata"][0]) # in cases where the original image has a file format that does not support From 1f69eef98616a1f79d84fba6aaf4305c7179e0b4 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:34:40 +0000 Subject: [PATCH 2/6] added new option --- gandlf_preprocess | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/gandlf_preprocess b/gandlf_preprocess index 39cee20b3..1522f8607 100644 --- a/gandlf_preprocess +++ b/gandlf_preprocess @@ -55,11 +55,25 @@ if __name__ == "__main__": help="This specifies the whether to apply data augmentation during output creation. Defaults to False", required=False, ) + parser.add_argument( + "-a", + "--cropzero", + metavar="", + type=bool, + default=False, + help="This specifies the whether to apply zero cropping during output creation. Defaults to False", + required=False, + ) args = parser.parse_args() preprocess_and_save( - args.inputdata, args.config, args.output, args.labelPad, args.applyaugs + args.inputdata, + args.config, + args.output, + args.labelPad, + args.applyaugs, + args.cropzero, ) print("Finished.") From 52fa5b10bed977b7463b09a53c01bcec88904245 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:36:07 +0000 Subject: [PATCH 3/6] added new function to calculate patch size --- GANDLF/utils/__init__.py | 1 + GANDLF/utils/imaging.py | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/GANDLF/utils/__init__.py b/GANDLF/utils/__init__.py index 0ecc697ed..39173e2b9 100644 --- a/GANDLF/utils/__init__.py +++ b/GANDLF/utils/__init__.py @@ -8,6 +8,7 @@ resample_image, perform_sanity_check_on_subject, write_training_patches, + get_correct_padding_size, ) from .tensor import ( diff --git a/GANDLF/utils/imaging.py b/GANDLF/utils/imaging.py index 94646b014..37e88a98c 100644 --- a/GANDLF/utils/imaging.py +++ b/GANDLF/utils/imaging.py @@ -1,4 +1,5 @@ -import os, pathlib, sys, math +from typing import Union +import os, pathlib, math, copy import numpy as np import SimpleITK as sitk import torchio @@ -126,8 +127,6 @@ def perform_sanity_check_on_subject(subject, parameters): # read the first image and save that for comparison file_reader_base = None - import copy - list_for_comparison = copy.deepcopy(parameters["headers"]["channelHeaders"]) if parameters["headers"]["labelHeader"] is not None: list_for_comparison.append("label") @@ -240,3 +239,22 @@ def write_training_patches(subject, params): img_to_write, os.path.join(training_output_dir_current_subject, "label_" + key + ext), ) + + +def get_correct_padding_size(patch_size: Union[list, tuple], model_dimension: int): + """ + This function returns the correct padding size based on the patch size and overlap. + + Args: + patch_size (Union[list, tuple]): The patch size. + model_dimension (int): The model dimension. + + Returns: + Union[list, tuple]: The correct padding size. + """ + psize_pad = list(np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int)) + # ensure that the patch size for z-axis is not 1 for 2d images + if model_dimension == 2: + psize_pad[-1] = 0 if psize_pad[-1] == 1 else psize_pad[-1] + + return psize_pad From 1e423d340779712f0899fa41a6d17a2b962163d7 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:36:38 +0000 Subject: [PATCH 4/6] using new function --- GANDLF/data/ImagesFromDataFrame.py | 97 ++++++++++++++++-------------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/GANDLF/data/ImagesFromDataFrame.py b/GANDLF/data/ImagesFromDataFrame.py index 4ce778b0c..70790f1fb 100644 --- a/GANDLF/data/ImagesFromDataFrame.py +++ b/GANDLF/data/ImagesFromDataFrame.py @@ -1,7 +1,9 @@ +from typing import Union import os from pathlib import Path import numpy as np +import pandas import torch import torchio from torchio.transforms import Pad @@ -12,6 +14,7 @@ perform_sanity_check_on_subject, resize_image, get_filename_extension_sanitized, + get_correct_padding_size, ) from .preprocessing import get_transforms_for_preprocessing from .augmentation import global_augs_dict @@ -31,30 +34,27 @@ # This function takes in a dataframe, with some other parameters and returns the dataloader def ImagesFromDataFrame( - dataframe, parameters, train, apply_zero_crop=False, loader_type="" -): + dataframe: pandas.DataFrame, + parameters: dict, + train: bool, + apply_zero_crop: bool = False, + loader_type: str = "", +) -> Union[torchio.SubjectsDataset, torchio.Queue]: """ - Reads the pandas dataframe and gives the dataloader to use for training/validation/testing - - Parameters - ---------- - dataframe : pandas.DataFrame - The main input dataframe which is calculated after splitting the data CSV - parameters : dict - The parameters dictionary - train : bool - If the dataloader is for training or not. For training, the patching infrastructure and data augmentation is applied. - apply_zero_crop : bool - If enabled, the crop_external_zero_plane is applied. - loader_type : str - Type of loader for printing. - - Returns - ------- - subjects_dataset: torchio.SubjectsDataset - This is the output for validation/testing, where patching and data augmentation is disregarded - patches_queue: torchio.Queue - This is the output for training, which is the subjects_dataset queue after patching and data augmentation is taken into account + Reads the pandas dataframe and gives the dataloader to use for training/validation/testing. + + Args: + dataframe (pandas.DataFrame): The main input dataframe which is calculated after splitting the data CSV. + parameters (dict): The parameters dictionary. + train (bool): If the dataloader is for training or not. + apply_zero_crop (bool, optional): If zero crop is to be applied. Defaults to False. + loader_type (str, optional): The type of loader to use for printing. Defaults to "". + + Raises: + ValueError: If the subject cannot be loaded. + + Returns: + Union[torchio.SubjectsDataset, torchio.Queue]: The dataloader queue for validation/testing (where patching and data augmentation is not required) or the subjects dataset for training. """ # store in previous variable names patch_size = parameters["patch_size"] @@ -63,11 +63,10 @@ def ImagesFromDataFrame( q_samples_per_volume = parameters["q_samples_per_volume"] q_num_workers = parameters["q_num_workers"] q_verbose = parameters["q_verbose"] - sampler = parameters["patch_sampler"] augmentations = parameters["data_augmentation"] preprocessing = parameters["data_preprocessing"] in_memory = parameters["in_memory"] - enable_padding = parameters["enable_padding"] + sampler = parameters["patch_sampler"] # Finding the dimension of the dataframe for computational purposes later num_row, num_col = dataframe.shape @@ -83,14 +82,6 @@ def ImagesFromDataFrame( predictionHeaders = headers["predictionHeaders"] subjectIDHeader = headers["subjectIDHeader"] - # this basically means that label sampler is selected with padding - if isinstance(sampler, dict): - sampler_padding = sampler["label"]["padding_type"] - sampler = "label" - else: - sampler = sampler.lower() # for easier parsing - sampler_padding = "symmetric" - resize_images_flag = False # if resize has been defined but resample is not (or is none) if not (preprocessing is None): @@ -251,14 +242,12 @@ def _save_resized_images( ) # # padding image, but only for label sampler, because we don't want to pad for uniform - if "label" in sampler or "weight" in sampler: - if enable_padding: - psize_pad = list( - np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int) - ) - # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html - padder = Pad(psize_pad, padding_mode=sampler_padding) - subject = padder(subject) + if sampler["enable_padding"]: + psize_pad = get_correct_padding_size( + patch_size, parameters["model"]["dimension"] + ) + padder = Pad(psize_pad, padding_mode=sampler["padding_mode"]) + subject = padder(subject) # load subject into memory: https://github.com/fepegar/torchio/discussions/568#discussioncomment-859027 if in_memory: @@ -291,16 +280,32 @@ def _save_resized_images( subjects_dataset = torchio.SubjectsDataset(subjects_list, transform=transform) if not train: return subjects_dataset - if sampler in ("weighted", "weightedsampler", "weightedsample"): - sampler = global_sampler_dict[sampler](patch_size, probability_map="label") - else: - sampler = global_sampler_dict[sampler](patch_size) + + # initialize the sampler + sampler_obj = global_sampler_dict[sampler["type"]](patch_size) + if sampler["type"] in ("weighted", "weightedsampler", "weightedsample"): + sampler_obj = global_sampler_dict[sampler["type"]]( + patch_size, probability_map="label" + ) + elif sampler["type"] in ("label", "labelsampler", "labelsample"): + # if biased sampling is detected, then we need to pass the class probabilities + if train and sampler["biased_sampling"]: + # initialize the class probabilities dict + label_probabilities = {} + if "sampling_weights" in parameters: + for class_index in parameters["sampling_weights"]: + label_probabilities[class_index] = parameters["sampling_weights"][ + class_index + ] + sampler_obj = global_sampler_dict[sampler["type"]]( + patch_size, label_probabilities=label_probabilities + ) # all of these need to be read from model.yaml patches_queue = torchio.Queue( subjects_dataset, max_length=q_max_length, samples_per_volume=q_samples_per_volume, - sampler=sampler, + sampler=sampler_obj, num_workers=q_num_workers, shuffle_subjects=True, shuffle_patches=True, From d8c065136dd3006252608fbaec8025d6d01ee230 Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:39:30 +0000 Subject: [PATCH 5/6] using new function --- GANDLF/data/ImagesFromDataFrame.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/GANDLF/data/ImagesFromDataFrame.py b/GANDLF/data/ImagesFromDataFrame.py index 14ffb5a01..70790f1fb 100644 --- a/GANDLF/data/ImagesFromDataFrame.py +++ b/GANDLF/data/ImagesFromDataFrame.py @@ -243,13 +243,9 @@ def _save_resized_images( # # padding image, but only for label sampler, because we don't want to pad for uniform if sampler["enable_padding"]: - psize_pad = list( - np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int) + psize_pad = get_correct_padding_size( + patch_size, parameters["model"]["dimension"] ) - # ensure that the patch size for z-axis is not 1 for 2d images - if parameters["model"]["dimension"] == 2: - psize_pad[-1] = 0 if psize_pad[-1] == 1 else psize_pad[-1] - # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html padder = Pad(psize_pad, padding_mode=sampler["padding_mode"]) subject = padder(subject) From abad6db8948a28374553b75a6e9e09740000859d Mon Sep 17 00:00:00 2001 From: scap3yvt <149599669+scap3yvt@users.noreply.github.com> Date: Mon, 22 Jan 2024 20:23:44 +0000 Subject: [PATCH 6/6] added comment --- GANDLF/cli/preprocess_and_save.py | 1 + 1 file changed, 1 insertion(+) diff --git a/GANDLF/cli/preprocess_and_save.py b/GANDLF/cli/preprocess_and_save.py index c35f307fe..1309c331c 100644 --- a/GANDLF/cli/preprocess_and_save.py +++ b/GANDLF/cli/preprocess_and_save.py @@ -35,6 +35,7 @@ def preprocess_and_save( output_dir (str): The output directory. label_pad_mode (str): The padding strategy for the label. Defaults to "constant". applyaugs (bool): If data augmentation is to be applied before saving the image. Defaults to False. + apply_zero_crop (bool): If zero cropping is to be applied before saving the image. Defaults to False. Raises: ValueError: Parameter check from previous