Commit

Merge pull request #777 from scap3yvt/776-add-option-to-use-different-weights-for-sampler-biasing-and-loss

Added option to use different weights for sampler biasing and loss
sarthakpati authored Jan 18, 2024
2 parents c666e09 + 0853b68 commit 2d5513a
Showing 8 changed files with 44 additions and 36 deletions.
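
Overview: before this change, the single parameters["weights"] dict did double duty, driving both the loss penalties and the biased patch sampler. This commit splits it into separate keys. A comment-only sketch of the resulting layout (meanings inferred from the hunks below):

# parameters["penalty_weights"]   # per-class multipliers applied to the loss terms
# parameters["sampling_weights"]  # per-class probabilities for biased patch sampling
# parameters["class_weights"]     # per-class weights kept under the unchanged key name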
3 changes: 2 additions & 1 deletion GANDLF/cli/recover_config.py
@@ -28,7 +28,8 @@ def recover_config(modelDir, outputFile):
         "testing_data",
         "device",
         "subject_spacing",
-        "weights",
+        "penalty_weights",
+        "sampling_weights",
         "class_weights",
     ]
8 changes: 5 additions & 3 deletions GANDLF/compute/generic.py
@@ -48,12 +48,14 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu"):
 
     # Calculate the weights here
     (
-        parameters["weights"],
+        parameters["penalty_weights"],
+        parameters["sampling_weights"],
         parameters["class_weights"],
     ) = get_class_imbalance_weights(parameters["training_data"], parameters)
 
-    print("Class weights : ", parameters["class_weights"])
-    print("Penalty weights: ", parameters["weights"])
+    print("Penalty weights : ", parameters["penalty_weights"])
+    print("Sampling weights: ", parameters["sampling_weights"])
+    print("Class weights : ", parameters["class_weights"])
 
     # get the train loader
     train_loader = get_train_loader(parameters)
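
For a two-class problem where class 1 is rare, the three returned dicts could look like this (hypothetical values; the meanings follow from the call sites changed in this commit):

penalty_weights = {0: 0.1, 1: 0.9}   # rare class contributes more to the loss
sampling_weights = {0: 0.1, 1: 0.9}  # rare class is sampled more often
class_weights = {0: 0.9, 1: 0.1}     # relative class frequencies in the data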
6 changes: 3 additions & 3 deletions GANDLF/data/ImagesFromDataFrame.py
@@ -296,9 +296,9 @@ def _save_resized_images(
     if train and sampler["biased_sampling"]:
         # initialize the class probabilities dict
         label_probabilities = {}
-        if "weights" in parameters:
-            for class_index in parameters["weights"]:
-                label_probabilities[class_index] = parameters["weights"][
+        if "sampling_weights" in parameters:
+            for class_index in parameters["sampling_weights"]:
+                label_probabilities[class_index] = parameters["sampling_weights"][
                     class_index
                 ]
         sampler_obj = global_sampler_dict[sampler["type"]](
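
For context, the label_probabilities dict built above ultimately drives the label sampler. A minimal standalone sketch, assuming the sampler resolved from global_sampler_dict is TorchIO's LabelSampler (patch size and weights are hypothetical):

import torchio as tio

# Built from parameters["sampling_weights"] in the loop above (hypothetical values).
label_probabilities = {0: 0.1, 1: 0.9}

# LabelSampler draws patches whose center voxel has label c with
# probability proportional to label_probabilities[c].
sampler = tio.LabelSampler(
    patch_size=(64, 64, 64),
    label_name="label",
    label_probabilities=label_probabilities,
)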
8 changes: 3 additions & 5 deletions GANDLF/inference_manager.py
@@ -36,11 +36,9 @@ def InferenceManager(dataframe, modelDir, parameters, device, outputDir=None):
 
     parameters["output_dir"] = outputDir
 
-    # # initialize parameters for inference
-    if not ("weights" in parameters):
-        parameters["weights"] = None  # no need for loss weights for inference
-    if not ("class_weights" in parameters):
-        parameters["class_weights"] = None  # no need for class weights for inference
+    # initialize parameters for inference
+    for key in ["penalty_weights", "sampling_weights", "class_weights"]:
+        parameters[key] = parameters.get(key, None)
 
     n_folds = parameters["nested_training"]["validation"]
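
The new loop leans on dict.get, which returns None for missing keys while leaving existing entries untouched; a minimal repro:

parameters = {"penalty_weights": {0: 0.5, 1: 0.5}}  # hypothetical pre-set key
for key in ["penalty_weights", "sampling_weights", "class_weights"]:
    parameters[key] = parameters.get(key, None)

assert parameters["penalty_weights"] == {0: 0.5, 1: 0.5}  # preserved
assert parameters["sampling_weights"] is None  # filled in for inference
assert parameters["class_weights"] is None  # filled in for inference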
12 changes: 6 additions & 6 deletions GANDLF/losses/regression.py
@@ -20,14 +20,14 @@ def CEL(prediction, target, params):
     target = torch.squeeze(target, -1)
 
     weights = None
-    if params.get("weights") is not None:
+    if params.get("penalty_weights") is not None:
         # Check that the number of classes matches the number of weights
-        num_classes = len(params["weights"])
+        num_classes = len(params["penalty_weights"])
         assert (
             prediction.shape[-1] == num_classes
         ), f"Number of classes {num_classes} does not match prediction shape {prediction.shape[-1]}"
 
-        weights = torch.FloatTensor(list(params["weights"].values()))
+        weights = torch.FloatTensor(list(params["penalty_weights"].values()))
         weights = weights.float().to(target.device)
 
     cel = CrossEntropyLoss(weight=weights)

@@ -93,12 +93,12 @@ def CCE_Generic(prediction, target, params, CCE_Type):
 
     for i in range(0, len(params["model"]["class_list"])):
         curr_ce_loss = CCE_Type(prediction[:, i, ...], target[:, i, ...])
-        if params["weights"] is not None:
-            curr_ce_loss = curr_ce_loss * params["weights"][i]
+        if params["penalty_weights"] is not None:
+            curr_ce_loss = curr_ce_loss * params["penalty_weights"][i]
         acc_ce_loss += curr_ce_loss
 
     # Take the mean of the loss if weights are not provided.
-    if params["weights"] is None:
+    if params["penalty_weights"] is None:
         acc_ce_loss = torch.mean(acc_ce_loss)
 
     return acc_ce_loss
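
A self-contained sketch of the CEL weighting path above (hypothetical shapes and weights): the dict's values become the per-class weight tensor passed to torch's CrossEntropyLoss.

import torch
from torch.nn import CrossEntropyLoss

penalty_weights = {0: 0.25, 1: 0.75}  # hypothetical per-class penalties
prediction = torch.randn(8, len(penalty_weights))  # (batch, num_classes) logits
target = torch.randint(0, 2, (8,))  # integer class labels

weights = torch.FloatTensor(list(penalty_weights.values())).to(target.device)
loss = CrossEntropyLoss(weight=weights)(prediction, target)
print(loss.item())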
18 changes: 9 additions & 9 deletions GANDLF/losses/segmentation.py
@@ -132,7 +132,7 @@ def MCD_loss(
         target,
         len(params["model"]["class_list"]),
         dice,
-        params["weights"],
+        params["penalty_weights"],
         None,
         1,
     )

@@ -157,7 +157,7 @@ def MCD_log_loss(
         target,
         len(params["model"]["class_list"]),
         dice,
-        params["weights"],
+        params["penalty_weights"],
         None,
         2,
     )

@@ -182,7 +182,7 @@ def MCC_loss(
         target,
         len(params["model"]["class_list"]),
         mcc,
-        params["weights"],
+        params["penalty_weights"],
         None,
         1,
     )

@@ -207,7 +207,7 @@ def MCC_log_loss(
         target,
         len(params["model"]["class_list"]),
         mcc,
-        params["weights"],
+        params["penalty_weights"],
         None,
         2,
     )

@@ -268,11 +268,11 @@ def MCT_loss(
 
     for i in range(num_classes):
         curr_loss = tversky_loss(predicted[:, i, ...], target[:, i, ...])
-        if params is not None and params.get("weights") is not None:
-            curr_loss = curr_loss * params["weights"][i]
+        if params is not None and params.get("penalty_weights") is not None:
+            curr_loss = curr_loss * params["penalty_weights"][i]
         acc_tv_loss += curr_loss
 
-    if params is not None and params.get("weights") is None:
+    if params is not None and params.get("penalty_weights") is None:
         acc_tv_loss /= num_classes
 
     return acc_tv_loss

@@ -343,8 +343,8 @@ def _focal_loss(preds, target, gamma, size_average=True):
         curr_loss = _focal_loss(
             predicted[:, i, ...], target[:, i, ...], gamma, size_average
         )
-        if params is not None and params.get("weights") is not None:
-            curr_loss = curr_loss * params["weights"][i]
+        if params is not None and params.get("penalty_weights") is not None:
+            curr_loss = curr_loss * params["penalty_weights"][i]
         acc_focal_loss += curr_loss
 
     return acc_focal_loss
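
All of the segmentation losses above share one pattern: compute a per-class loss, scale it by penalty_weights[i] when the dict is provided, and average over classes otherwise. A toy sketch of that pattern (the squared-error per-class term is a stand-in for dice/mcc/tversky/focal):

import torch

def weighted_multiclass_loss(prediction, target, penalty_weights=None):
    # prediction/target: (batch, num_classes, ...) tensors, one channel per class
    num_classes = prediction.shape[1]
    acc_loss = 0.0
    for i in range(num_classes):
        per_class = torch.mean((prediction[:, i, ...] - target[:, i, ...]) ** 2)
        if penalty_weights is not None:
            per_class = per_class * penalty_weights[i]
        acc_loss = acc_loss + per_class
    if penalty_weights is None:
        acc_loss = acc_loss / num_classes
    return acc_loss

pred, tgt = torch.rand(2, 3, 16, 16), torch.rand(2, 3, 16, 16)
print(weighted_multiclass_loss(pred, tgt, penalty_weights={0: 0.2, 1: 0.3, 2: 0.5}))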
21 changes: 14 additions & 7 deletions GANDLF/utils/tensor.py
@@ -283,7 +283,8 @@ def get_class_imbalance_weights_classification(training_df, params):
     for i in range(params["model"]["num_classes"]):
         penalty_dict[i] /= penalty_sum
 
-    return penalty_dict, weight_dict
+    # passing None for sampling_weights because there is no clear way to calculate this for classification tasks which do not have a label
+    return penalty_dict, None, weight_dict
 
 
 def get_class_imbalance_weights_segmentation(training_data_loader, parameters):

@@ -349,7 +350,7 @@ def get_class_imbalance_weights_segmentation(training_data_loader, parameters):
         for key, val in penalty.items()
     }
 
-    return penalty_dict, weights_dict
+    return penalty_dict, penalty_dict, weights_dict
 
 
 def get_class_imbalance_weights(training_df, params):

@@ -363,10 +364,11 @@
     Returns:
         float, float: The penalty and class weights for different classes under consideration for classification.
     """
-    penalty_weights, class_weights = None, None
+    penalty_weights, sampling_weights, class_weights = None, None, None
     if params["weighted_loss"] or params["patch_sampler"]["biased_sampling"]:
-        (penalty_weights, class_weights) = (
-            params.get("weights", None),
+        (penalty_weights, sampling_weights, class_weights) = (
+            params.get("penalty_weights", None),
+            params.get("sampling_weights", None),
             params.get("class_weights", None),
         )
         # this default is needed for openfl

@@ -383,11 +385,15 @@
                 else penalty_weights
             )
 
-    if penalty_weights is None or class_weights is None:
+    # calculate the penalty/sampling weights only if one of the following conditions are met
+    if (params["weighted_loss"] and (penalty_weights is None)) or (
+        params["patch_sampler"]["biased_sampling"] and (sampling_weights is None)
+    ):
         print("Calculating weights")
         if params["problem_type"] == "classification":
             (
                 penalty_weights,
+                sampling_weights,
                 class_weights,
             ) = get_class_imbalance_weights_classification(training_df, params)
         elif params["problem_type"] == "segmentation":

@@ -410,13 +416,14 @@
             (
                 penalty_weights,
+                sampling_weights,
                 class_weights,
             ) = get_class_imbalance_weights_segmentation(penalty_loader, params)
             del penalty_data, penalty_loader
     else:
         print("Using weights from config file")
 
-    return penalty_weights, class_weights
+    return penalty_weights, sampling_weights, class_weights
 
 
 def get_linear_interpolation_mode(dimensionality):
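
A toy summary of the new three-way return contract (numbers hypothetical): segmentation reuses its penalty dict as the sampling weights, while classification returns None for them because there is no label image to sample from.

def toy_get_class_imbalance_weights(problem_type):
    penalty_dict = {0: 0.2, 1: 0.8}  # hypothetical penalties
    class_dict = {0: 0.8, 1: 0.2}  # hypothetical class frequencies
    if problem_type == "segmentation":
        return penalty_dict, penalty_dict, class_dict  # sampling == penalty
    return penalty_dict, None, class_dict  # classification: no sampling weights

penalty_weights, sampling_weights, class_weights = toy_get_class_imbalance_weights("classification")
assert sampling_weights is None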
4 changes: 2 additions & 2 deletions samples/config_all_options.yaml
@@ -85,7 +85,7 @@ patch_sampler: uniform
 # type: label,
 # enable_padding: True,
 # padding_mode: symmetric, # for options, see 'mode' in https://numpy.org/doc/stable/reference/generated/numpy.pad.html
-# biased_sampling: True, # adds additional sampling probability of labels based on "class_weights"; only gets invoked when using label sampler.
+# biased_sampling: True, # adds additional sampling probability of labels based on "sampling_weights" key; only gets invoked when using label sampler. If not present, gets calculated using the same mechanism as weighted_loss
 # }
 # If enabled, this parameter pads images and labels when label sampler is used
 enable_padding: False
@@ -116,7 +116,7 @@ scheduler:
 # use mse_torch for regression/classification problems and dice for segmentation
 loss_function: dc
 # this parameter weights the loss to handle imbalanced losses better
-weighted_loss: True # generates new keys "class_weights" and "weights" that handle the aggregate weights of the class and penalties per label, respectively
+weighted_loss: True # generates new keys "class_weights" and "penalty_weights" that handle the aggregate weights of the class and penalties per label, respectively
 #loss_function:
 #  {
 #    'mse':{
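
Putting the two YAML options together: if the commented patch_sampler block above were enabled alongside weighted_loss, the parsed parameters would contain roughly the following (a hedged Python mirror of the YAML; values hypothetical):

parameters = {
    "weighted_loss": True,  # triggers generation of "penalty_weights" and "class_weights"
    "patch_sampler": {
        "type": "label",
        "enable_padding": True,
        "padding_mode": "symmetric",
        "biased_sampling": True,  # reads parameters["sampling_weights"]
    },
}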
