diff --git a/GANDLF/compute/generic.py b/GANDLF/compute/generic.py index d3674f42e..6b171cc43 100644 --- a/GANDLF/compute/generic.py +++ b/GANDLF/compute/generic.py @@ -45,6 +45,16 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu parameters = populate_header_in_parameters( parameters, headers_to_populate_train ) + + # Calculate the weights here + ( + parameters["weights"], + parameters["class_weights"], + ) = get_class_imbalance_weights(parameters["training_data"], parameters) + + print("Class weights : ", parameters["class_weights"]) + print("Penalty weights: ", parameters["weights"]) + # get the train loader train_loader = get_train_loader(parameters) parameters["training_samples_size"] = len(train_loader) @@ -91,15 +101,6 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu scheduler = get_scheduler(parameters) - # Calculate the weights here - ( - parameters["weights"], - parameters["class_weights"], - ) = get_class_imbalance_weights(parameters["training_data"], parameters) - - print("Class weights : ", parameters["class_weights"]) - print("Penalty weights: ", parameters["weights"]) - else: scheduler = None diff --git a/GANDLF/data/ImagesFromDataFrame.py b/GANDLF/data/ImagesFromDataFrame.py index 4ce778b0c..68416d7c4 100644 --- a/GANDLF/data/ImagesFromDataFrame.py +++ b/GANDLF/data/ImagesFromDataFrame.py @@ -63,11 +63,10 @@ def ImagesFromDataFrame( q_samples_per_volume = parameters["q_samples_per_volume"] q_num_workers = parameters["q_num_workers"] q_verbose = parameters["q_verbose"] - sampler = parameters["patch_sampler"] augmentations = parameters["data_augmentation"] preprocessing = parameters["data_preprocessing"] in_memory = parameters["in_memory"] - enable_padding = parameters["enable_padding"] + sampler = parameters["patch_sampler"] # Finding the dimension of the dataframe for computational purposes later num_row, num_col = dataframe.shape @@ -83,14 +82,6 @@ def ImagesFromDataFrame( predictionHeaders = headers["predictionHeaders"] subjectIDHeader = headers["subjectIDHeader"] - # this basically means that label sampler is selected with padding - if isinstance(sampler, dict): - sampler_padding = sampler["label"]["padding_type"] - sampler = "label" - else: - sampler = sampler.lower() # for easier parsing - sampler_padding = "symmetric" - resize_images_flag = False # if resize has been defined but resample is not (or is none) if not (preprocessing is None): @@ -251,14 +242,16 @@ def _save_resized_images( ) # # padding image, but only for label sampler, because we don't want to pad for uniform - if "label" in sampler or "weight" in sampler: - if enable_padding: - psize_pad = list( - np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int) - ) - # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html - padder = Pad(psize_pad, padding_mode=sampler_padding) - subject = padder(subject) + if sampler["enable_padding"]: + psize_pad = list( + np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int) + ) + # ensure that the patch size for z-axis is not 1 for 2d images + if parameters["model"]["dimension"] == 2: + psize_pad[-1] = 0 if psize_pad[-1] == 1 else psize_pad[-1] + # for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html + padder = Pad(psize_pad, padding_mode=sampler["padding_mode"]) + subject = padder(subject) # load subject into memory: https://github.com/fepegar/torchio/discussions/568#discussioncomment-859027 if in_memory: @@ -291,16 +284,32 @@ def _save_resized_images( subjects_dataset = torchio.SubjectsDataset(subjects_list, transform=transform) if not train: return subjects_dataset - if sampler in ("weighted", "weightedsampler", "weightedsample"): - sampler = global_sampler_dict[sampler](patch_size, probability_map="label") - else: - sampler = global_sampler_dict[sampler](patch_size) + + # initialize the sampler + sampler_obj = global_sampler_dict[sampler["type"]](patch_size) + if sampler["type"] in ("weighted", "weightedsampler", "weightedsample"): + sampler_obj = global_sampler_dict[sampler["type"]]( + patch_size, probability_map="label" + ) + elif sampler["type"] in ("label", "labelsampler", "labelsample"): + # if biased sampling is detected, then we need to pass the class probabilities + if train and sampler["biased_sampling"]: + # initialize the class probabilities dict + label_probabilities = {} + if "weights" in parameters: + for class_index in parameters["weights"]: + label_probabilities[class_index] = parameters["weights"][ + class_index + ] + sampler_obj = global_sampler_dict[sampler["type"]]( + patch_size, label_probabilities=label_probabilities + ) # all of these need to be read from model.yaml patches_queue = torchio.Queue( subjects_dataset, max_length=q_max_length, samples_per_volume=q_samples_per_volume, - sampler=sampler, + sampler=sampler_obj, num_workers=q_num_workers, shuffle_subjects=True, shuffle_patches=True, diff --git a/GANDLF/parseConfig.py b/GANDLF/parseConfig.py index b62d7b371..f839fe636 100644 --- a/GANDLF/parseConfig.py +++ b/GANDLF/parseConfig.py @@ -17,7 +17,6 @@ "save_output": False, # save outputs during validation/testing "in_memory": False, # pin data to cpu memory "pin_memory_dataloader": False, # pin data to gpu memory - "enable_padding": False, # if padding needs to be done when "patch_sampler" is "label" "scaling_factor": 1, # scaling factor for regression problems "q_max_length": 100, # the max length of queue "q_samples_per_volume": 10, # number of samples per volume @@ -39,7 +38,6 @@ ## dictionary to define string defaults for appropriate options parameter_defaults_string = { "optimizer": "adam", # the optimizer - "patch_sampler": "uniform", # type of sampling strategy "scheduler": "triangle_modified", # the default scheduler "clip_mode": None, # default clip mode } @@ -640,6 +638,30 @@ def parseConfig(config_file_path, version_check_flag=True): print("DeprecationWarning: 'opt' has been superseded by 'optimizer'") params["optimizer"] = params["opt"] + # initialize defaults for patch sampler + temp_patch_sampler_dict = { + "type": "uniform", + "enable_padding": False, + "padding_mode": "symmetric", + "biased_sampling": False, + } + # check if patch_sampler is defined in the config + if "patch_sampler" in params: + # if "patch_sampler" is a string, then it is the type of sampler + if isinstance(params["patch_sampler"], str): + print( + "WARNING: Defining 'patch_sampler' as a string will be deprecated in a future release, please use a dictionary instead" + ) + temp_patch_sampler_dict["type"] = params["patch_sampler"].lower() + elif isinstance(params["patch_sampler"], dict): + # dict requires special handling + for key in params["patch_sampler"]: + temp_patch_sampler_dict[key] = params["patch_sampler"][key] + + # now assign the dict back to the params + params["patch_sampler"] = temp_patch_sampler_dict + del temp_patch_sampler_dict + # define defaults for current_parameter in parameter_defaults: params = initialize_parameter( diff --git a/GANDLF/utils/tensor.py b/GANDLF/utils/tensor.py index da9b26cb1..7027234eb 100644 --- a/GANDLF/utils/tensor.py +++ b/GANDLF/utils/tensor.py @@ -364,7 +364,7 @@ def get_class_imbalance_weights(training_df, params): float, float: The penalty and class weights for different classes under consideration for classification. """ penalty_weights, class_weights = None, None - if params["weighted_loss"]: + if params["weighted_loss"] or params["patch_sampler"]["biased_sampling"]: (penalty_weights, class_weights) = ( params.get("weights", None), params.get("class_weights", None), diff --git a/samples/config_all_options.yaml b/samples/config_all_options.yaml index d1e046a7a..2348605f1 100644 --- a/samples/config_all_options.yaml +++ b/samples/config_all_options.yaml @@ -82,10 +82,10 @@ patch_sampler: uniform # patch_sampler: label # patch_sampler: # { -# label: -# { -# padding_type: constant # how the label gets padded, for options, see 'mode' in https://numpy.org/doc/stable/reference/generated/numpy.pad.html -# } +# type: label, +# enable_padding: True, +# padding_mode: symmetric, # for options, see 'mode' in https://numpy.org/doc/stable/reference/generated/numpy.pad.html +# biased_sampling: True, # adds additional sampling probability of labels based on "class_weights"; only gets invoked when using label sampler. # } # If enabled, this parameter pads images and labels when label sampler is used enable_padding: False @@ -116,7 +116,7 @@ scheduler: # use mse_torch for regression/classification problems and dice for segmentation loss_function: dc # this parameter weights the loss to handle imbalanced losses better -weighted_loss: True +weighted_loss: True # generates new keys "class_weights" and "weights" that handle the aggregate weights of the class and penalties per label, respectively #loss_function: # { # 'mse':{ diff --git a/testing/test_full.py b/testing/test_full.py index 772258320..52920a772 100644 --- a/testing/test_full.py +++ b/testing/test_full.py @@ -1465,7 +1465,11 @@ def test_generic_cli_function_preprocess(): parameters["model"]["num_channels"] = 3 parameters["model"]["architecture"] = "unet" parameters["metrics"] = ["dice"] - parameters["patch_sampler"] = "label" + parameters["patch_sampler"] = { + "type": "label", + "enable_padding": True, + "biased_sampling": True, + } parameters["weighted_loss"] = True parameters["save_output"] = True parameters["data_preprocessing"]["to_canonical"] = None @@ -2041,6 +2045,16 @@ def test_train_checkpointing_segmentation_rad_2d(device): parameters = parseConfig( testingDir + "/config_segmentation.yaml", version_check_flag=False ) + parameters["patch_sampler"] = { + "type": "label", + "enable_padding": True, + "biased_sampling": True, + } + file_config_temp = write_temp_config_path(parameters) + parameters = parseConfig( + file_config_temp, version_check_flag=False + ) + training_data, parameters["headers"] = parseTrainingCSV( inputDir + "/train_2d_rad_segmentation.csv" )