Skip to content

Commit

Permalink
Merge pull request #771 from scap3yvt/767-add-option-to-bias-the-patc…
Browse files Browse the repository at this point in the history
…h-extraction-for-label-sampler

Added option to bias the patch extraction for label sampler
  • Loading branch information
sarthakpati authored Jan 18, 2024
2 parents ddf8fb9 + ad9709b commit c666e09
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 41 deletions.
19 changes: 10 additions & 9 deletions GANDLF/compute/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
parameters = populate_header_in_parameters(
parameters, headers_to_populate_train
)

# Calculate the weights here
(
parameters["weights"],
parameters["class_weights"],
) = get_class_imbalance_weights(parameters["training_data"], parameters)

print("Class weights : ", parameters["class_weights"])
print("Penalty weights: ", parameters["weights"])

# get the train loader
train_loader = get_train_loader(parameters)
parameters["training_samples_size"] = len(train_loader)
Expand Down Expand Up @@ -91,15 +101,6 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu

scheduler = get_scheduler(parameters)

# Calculate the weights here
(
parameters["weights"],
parameters["class_weights"],
) = get_class_imbalance_weights(parameters["training_data"], parameters)

print("Class weights : ", parameters["class_weights"])
print("Penalty weights: ", parameters["weights"])

else:
scheduler = None

Expand Down
55 changes: 32 additions & 23 deletions GANDLF/data/ImagesFromDataFrame.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,10 @@ def ImagesFromDataFrame(
q_samples_per_volume = parameters["q_samples_per_volume"]
q_num_workers = parameters["q_num_workers"]
q_verbose = parameters["q_verbose"]
sampler = parameters["patch_sampler"]
augmentations = parameters["data_augmentation"]
preprocessing = parameters["data_preprocessing"]
in_memory = parameters["in_memory"]
enable_padding = parameters["enable_padding"]
sampler = parameters["patch_sampler"]

# Finding the dimension of the dataframe for computational purposes later
num_row, num_col = dataframe.shape
Expand All @@ -83,14 +82,6 @@ def ImagesFromDataFrame(
predictionHeaders = headers["predictionHeaders"]
subjectIDHeader = headers["subjectIDHeader"]

# this basically means that label sampler is selected with padding
if isinstance(sampler, dict):
sampler_padding = sampler["label"]["padding_type"]
sampler = "label"
else:
sampler = sampler.lower() # for easier parsing
sampler_padding = "symmetric"

resize_images_flag = False
# if resize has been defined but resample is not (or is none)
if not (preprocessing is None):
Expand Down Expand Up @@ -251,14 +242,16 @@ def _save_resized_images(
)

# # padding image, but only for label sampler, because we don't want to pad for uniform
if "label" in sampler or "weight" in sampler:
if enable_padding:
psize_pad = list(
np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int)
)
# for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
padder = Pad(psize_pad, padding_mode=sampler_padding)
subject = padder(subject)
if sampler["enable_padding"]:
psize_pad = list(
np.asarray(np.ceil(np.divide(patch_size, 2)), dtype=int)
)
# ensure that the patch size for z-axis is not 1 for 2d images
if parameters["model"]["dimension"] == 2:
psize_pad[-1] = 0 if psize_pad[-1] == 1 else psize_pad[-1]
# for modes: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
padder = Pad(psize_pad, padding_mode=sampler["padding_mode"])
subject = padder(subject)

# load subject into memory: https://github.com/fepegar/torchio/discussions/568#discussioncomment-859027
if in_memory:
Expand Down Expand Up @@ -291,16 +284,32 @@ def _save_resized_images(
subjects_dataset = torchio.SubjectsDataset(subjects_list, transform=transform)
if not train:
return subjects_dataset
if sampler in ("weighted", "weightedsampler", "weightedsample"):
sampler = global_sampler_dict[sampler](patch_size, probability_map="label")
else:
sampler = global_sampler_dict[sampler](patch_size)

# initialize the sampler
sampler_obj = global_sampler_dict[sampler["type"]](patch_size)
if sampler["type"] in ("weighted", "weightedsampler", "weightedsample"):
sampler_obj = global_sampler_dict[sampler["type"]](
patch_size, probability_map="label"
)
elif sampler["type"] in ("label", "labelsampler", "labelsample"):
# if biased sampling is detected, then we need to pass the class probabilities
if train and sampler["biased_sampling"]:
# initialize the class probabilities dict
label_probabilities = {}
if "weights" in parameters:
for class_index in parameters["weights"]:
label_probabilities[class_index] = parameters["weights"][
class_index
]
sampler_obj = global_sampler_dict[sampler["type"]](
patch_size, label_probabilities=label_probabilities
)
# all of these need to be read from model.yaml
patches_queue = torchio.Queue(
subjects_dataset,
max_length=q_max_length,
samples_per_volume=q_samples_per_volume,
sampler=sampler,
sampler=sampler_obj,
num_workers=q_num_workers,
shuffle_subjects=True,
shuffle_patches=True,
Expand Down
26 changes: 24 additions & 2 deletions GANDLF/parseConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"save_output": False, # save outputs during validation/testing
"in_memory": False, # pin data to cpu memory
"pin_memory_dataloader": False, # pin data to gpu memory
"enable_padding": False, # if padding needs to be done when "patch_sampler" is "label"
"scaling_factor": 1, # scaling factor for regression problems
"q_max_length": 100, # the max length of queue
"q_samples_per_volume": 10, # number of samples per volume
Expand All @@ -39,7 +38,6 @@
## dictionary to define string defaults for appropriate options
parameter_defaults_string = {
"optimizer": "adam", # the optimizer
"patch_sampler": "uniform", # type of sampling strategy
"scheduler": "triangle_modified", # the default scheduler
"clip_mode": None, # default clip mode
}
Expand Down Expand Up @@ -640,6 +638,30 @@ def parseConfig(config_file_path, version_check_flag=True):
print("DeprecationWarning: 'opt' has been superseded by 'optimizer'")
params["optimizer"] = params["opt"]

# initialize defaults for patch sampler
temp_patch_sampler_dict = {
"type": "uniform",
"enable_padding": False,
"padding_mode": "symmetric",
"biased_sampling": False,
}
# check if patch_sampler is defined in the config
if "patch_sampler" in params:
# if "patch_sampler" is a string, then it is the type of sampler
if isinstance(params["patch_sampler"], str):
print(
"WARNING: Defining 'patch_sampler' as a string will be deprecated in a future release, please use a dictionary instead"
)
temp_patch_sampler_dict["type"] = params["patch_sampler"].lower()
elif isinstance(params["patch_sampler"], dict):
# dict requires special handling
for key in params["patch_sampler"]:
temp_patch_sampler_dict[key] = params["patch_sampler"][key]

# now assign the dict back to the params
params["patch_sampler"] = temp_patch_sampler_dict
del temp_patch_sampler_dict

# define defaults
for current_parameter in parameter_defaults:
params = initialize_parameter(
Expand Down
2 changes: 1 addition & 1 deletion GANDLF/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def get_class_imbalance_weights(training_df, params):
float, float: The penalty and class weights for different classes under consideration for classification.
"""
penalty_weights, class_weights = None, None
if params["weighted_loss"]:
if params["weighted_loss"] or params["patch_sampler"]["biased_sampling"]:
(penalty_weights, class_weights) = (
params.get("weights", None),
params.get("class_weights", None),
Expand Down
10 changes: 5 additions & 5 deletions samples/config_all_options.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ patch_sampler: uniform
# patch_sampler: label
# patch_sampler:
# {
# label:
# {
# padding_type: constant # how the label gets padded, for options, see 'mode' in https://numpy.org/doc/stable/reference/generated/numpy.pad.html
# }
# type: label,
# enable_padding: True,
# padding_mode: symmetric, # for options, see 'mode' in https://numpy.org/doc/stable/reference/generated/numpy.pad.html
# biased_sampling: True, # adds additional sampling probability of labels based on "class_weights"; only gets invoked when using label sampler.
# }
# If enabled, this parameter pads images and labels when label sampler is used
enable_padding: False
Expand Down Expand Up @@ -116,7 +116,7 @@ scheduler:
# use mse_torch for regression/classification problems and dice for segmentation
loss_function: dc
# this parameter weights the loss to handle imbalanced losses better
weighted_loss: True
weighted_loss: True # generates new keys "class_weights" and "weights" that handle the aggregate weights of the class and penalties per label, respectively
#loss_function:
# {
# 'mse':{
Expand Down
16 changes: 15 additions & 1 deletion testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,11 @@ def test_generic_cli_function_preprocess():
parameters["model"]["num_channels"] = 3
parameters["model"]["architecture"] = "unet"
parameters["metrics"] = ["dice"]
parameters["patch_sampler"] = "label"
parameters["patch_sampler"] = {
"type": "label",
"enable_padding": True,
"biased_sampling": True,
}
parameters["weighted_loss"] = True
parameters["save_output"] = True
parameters["data_preprocessing"]["to_canonical"] = None
Expand Down Expand Up @@ -2041,6 +2045,16 @@ def test_train_checkpointing_segmentation_rad_2d(device):
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
parameters["patch_sampler"] = {
"type": "label",
"enable_padding": True,
"biased_sampling": True,
}
file_config_temp = write_temp_config_path(parameters)
parameters = parseConfig(
file_config_temp, version_check_flag=False
)

training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
Expand Down

0 comments on commit c666e09

Please sign in to comment.