diff --git a/GANDLF/compute/loss_and_metric.py b/GANDLF/compute/loss_and_metric.py
index 41b3fc32f..7642dd7cc 100644
--- a/GANDLF/compute/loss_and_metric.py
+++ b/GANDLF/compute/loss_and_metric.py
@@ -25,26 +25,22 @@ def get_metric_output(metric_function, predicted, ground_truth, params):
 
 def get_loss_and_metrics(image, ground_truth, predicted, params):
     """
-    image: torch.Tensor
-        The input image stack according to requirements
-    ground_truth : torch.Tensor
-        The input ground truth for the corresponding image label
-    predicted : torch.Tensor
-        The input predicted label for the corresponding image label
-    params : dict
-        The parameters passed by the user yaml
-
-    Returns
-    -------
-    loss : torch.Tensor
-        The computed loss from the label and the output
-    metric_output : torch.Tensor
-        The computed metric from the label and the output
+    This function computes the loss and metrics for a given image, ground truth and predicted output.
+
+    Args:
+        image (torch.Tensor): The input image stack according to requirements.
+        ground_truth (torch.Tensor): The input ground truth for the corresponding image label.
+        predicted (torch.Tensor): The input predicted label for the corresponding image label.
+        params (dict): The parameters passed by the user yaml.
+
+    Returns:
+        torch.Tensor: The computed loss from the label and the prediction.
+        dict: The computed metric from the label and the prediction.
     """
-    # this is currently only happening for mse_torch
-    if isinstance(params["loss_function"], dict):  # check for mse_torch
-        loss_function = global_losses_dict["mse"]
+    # loss_function is a dict when it carries per-loss options (e.g. mse, focal)
+    if isinstance(params["loss_function"], dict):
+        loss_function = global_losses_dict[list(params["loss_function"].keys())[0]]
     else:
         loss_str_lower = params["loss_function"].lower()
         if loss_str_lower in global_losses_dict:
diff --git a/GANDLF/losses/__init__.py b/GANDLF/losses/__init__.py
index 4acc05e4a..4e6d04a86 100644
--- a/GANDLF/losses/__init__.py
+++ b/GANDLF/losses/__init__.py
@@ -1,9 +1,15 @@
 """
 All the losses are to be called from here
 """
-from .segmentation import MCD_loss, MCD_log_loss, MCT_loss, KullbackLeiblerDivergence
+from .segmentation import (
+    MCD_loss,
+    MCD_log_loss,
+    MCT_loss,
+    KullbackLeiblerDivergence,
+    FocalLoss,
+)
 from .regression import CE, CEL, MSE_loss, L1_loss
-from .hybrid import DCCE, DCCE_Logits
+from .hybrid import DCCE, DCCE_Logits, DC_Focal
 
 
 # global defines for the losses
@@ -20,4 +26,6 @@
     "tversky": MCT_loss,
     "kld": KullbackLeiblerDivergence,
     "l1": L1_loss,
+    "focal": FocalLoss,
+    "dc_focal": DC_Focal,
 }
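
With the hard-coded `"mse"` lookup replaced by a generic key lookup, any dict-style `loss_function` entry now resolves to its own implementation in `global_losses_dict`. A minimal standalone sketch of that resolution (the `params` literal is illustrative, not taken from this patch):

    params = {"loss_function": {"focal": {"gamma": 2.0, "size_average": True}}}
    loss_name = list(params["loss_function"].keys())[0]
    assert loss_name == "focal"  # global_losses_dict[loss_name] -> FocalLoss
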
""" dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic( predicted_mask, ground_truth, params, CE @@ -30,21 +26,32 @@ def DCCE_Logits(predicted_mask, ground_truth, params): """ Calculates the Dice-Cross-Entropy loss using logits. - Parameters - ---------- - predicted_mask : torch.Tensor - Predicted mask logits - ground_truth : torch.Tensor - Ground truth mask - params : dict - Dictionary of parameters - - Returns - ------- - torch.Tensor - Calculated loss + Args: + predicted_mask (torch.Tensor): The predicted mask. + ground_truth (torch.Tensor): The ground truth mask. + params (dict): The parameters. + + Returns: + torch.Tensor: The calculated loss. """ dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic( predicted_mask, ground_truth, params, CE_Logits ) return dcce_loss + + +def DC_Focal(predicted_mask, ground_truth, params): + """ + Calculates the Dice-Focal loss. + + Args: + predicted_mask (torch.Tensor): The predicted mask. + ground_truth (torch.Tensor): The ground truth mask. + params (dict): The parameters. + + Returns: + torch.Tensor: The calculated loss. + """ + return MCD_loss(predicted_mask, ground_truth, params) + FocalLoss( + predicted_mask, ground_truth, params + ) diff --git a/GANDLF/losses/segmentation.py b/GANDLF/losses/segmentation.py index fa23290b6..5591e2c1e 100644 --- a/GANDLF/losses/segmentation.py +++ b/GANDLF/losses/segmentation.py @@ -3,22 +3,16 @@ # Dice scores and dice losses -def dice(predicted, target): +def dice(predicted, target) -> torch.Tensor: """ - This function computes a dice score between two tensors + This function computes a dice score between two tensors. - Parameters - ---------- - predicted : Tensor - predicted value by the network - target : Tensor - Required target label to match the predicted with - - Returns - ------- - Tensor - Computed Dice Score + Args: + predicted (_type_): Predicted value by the network. + target (_type_): Required target label to match the predicted with + Returns: + torch.Tensor: The computed dice score. """ predicted_flat = predicted.flatten() label_flat = target.flatten() @@ -36,7 +30,7 @@ def MCD(predicted, target, num_class, weights=None, ignore_class=None, loss_type This function computes the mean class dice score between two tensors Args: - predicted (torch.Tensor): predicted generally by the network + predicted (torch.Tensor): Predicted generally by the network target (torch.Tensor): Required target label to match the predicted with num_class (int): Number of classes (including the background class) weights (list, optional): Dice weights for each class (excluding the background class), defaults to None @@ -105,7 +99,7 @@ def tversky_loss(predicted, target, alpha=0.5, beta=0.5): This function calculates the Tversky loss between two tensors. Args: - predicted (torch.Tensor): predicted generally by the network + predicted (torch.Tensor): Predicted generally by the network target (torch.Tensor): Required target label to match the predicted with alpha (float, optional): Weight of false positives. Defaults to 0.5. beta (float, optional): Weight of false negatives. Defaults to 0.5. @@ -138,7 +132,7 @@ def MCT_loss(predicted, target, params=None): This function calculates the Multi-Class Tversky loss between two tensors. 
diff --git a/GANDLF/parseConfig.py b/GANDLF/parseConfig.py
index e79d60241..b62d7b371 100644
--- a/GANDLF/parseConfig.py
+++ b/GANDLF/parseConfig.py
@@ -41,7 +41,6 @@
     "optimizer": "adam",  # the optimizer
     "patch_sampler": "uniform",  # type of sampling strategy
     "scheduler": "triangle_modified",  # the default scheduler
-    "loss_function": "dc",  # default loss
     "clip_mode": None,  # default clip mode
 }
 
@@ -186,6 +185,11 @@ def parseConfig(config_file_path, version_check_flag=True):
             params["loss_function"] = {}
             params["loss_function"]["mse"] = {}
             params["loss_function"]["mse"]["reduction"] = "mean"
+        elif params["loss_function"] == "focal":
+            params["loss_function"] = {}
+            params["loss_function"]["focal"] = {}
+            params["loss_function"]["focal"]["gamma"] = 2.0
+            params["loss_function"]["focal"]["size_average"] = True
 
     assert "metrics" in params, "'metrics' needs to be defined in the config file"
     if "metrics" in params:
diff --git a/docs/customize.md b/docs/customize.md
index 5a1688d8c..24347f4cf 100644
--- a/docs/customize.md
+++ b/docs/customize.md
@@ -39,7 +39,7 @@ This file contains mid-level information regarding various parameters that can b
 
 - Defined in the `loss_function` parameter of the model configuration.
 - By passing `weighted_loss: True`, the loss function will be weighted by the inverse of the class frequency.
 - This parameter controls the function which the model is trained. All options can be found [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/losses/__init__.py). Some examples are:
-  - Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`)
+  - Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`)
   - Classification/regression: mean squared error (`mse`)
   - And many more.
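
Note that only the bare string "focal" is expanded by the branch above; "dc_focal" passes through as a plain string, so FocalLoss falls back to its built-in defaults for that hybrid. Shown as plain data (illustrative, not from the patch):

    parsed_focal = {"loss_function": {"focal": {"gamma": 2.0, "size_average": True}}}
    parsed_dc_focal = {"loss_function": "dc_focal"}  # stays a string, defaults apply
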
diff --git a/samples/config_all_options.yaml b/samples/config_all_options.yaml
index 3ef116137..0c0d405db 100644
--- a/samples/config_all_options.yaml
+++ b/samples/config_all_options.yaml
@@ -112,6 +112,7 @@ scheduler:
 # Set which loss function you want to use - options : 'dc' - for dice only, 'dcce' - for sum of dice and CE and you can guess the next (only lower-case please)
 # options: dc (dice only), dc_log (-log of dice), ce (), dcce (sum of dice and ce), mse () ...
 # mse is the MSE defined by torch and can define a variable 'reduction'; see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
+# focal is the focal loss and can define 2 variables: gamma and size_average
 # use mse_torch for regression/classification problems and dice for segmentation
 loss_function: dc
 # this parameter weights the loss to handle imbalanced losses better
@@ -122,6 +123,12 @@ weighted_loss: True
 #   'reduction': 'mean' # see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss for all options
 #   }
 # }
+# loss_function:
+# {
+#   'focal': {
+#     'gamma': 1.0
+#   }
+# }
 # Which optimizer do you want to use - sgd, asgd, adam, adamw, adamax, sparseadam, rprop, adadelta, adagrad, rmsprop,
 # each has their own options and functionalities, which are initialized with defaults, see GANDLF.optimizers.wrap_torch for details
 optimizer: adam
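
As a quick numeric illustration of the `gamma` variable documented above: the focal factor (1 - pt)^gamma shrinks the contribution of already-confident predictions as gamma grows, and gamma = 0 recovers plain cross-entropy (the confidence values below are made up):

    import torch

    pt = torch.tensor([0.9, 0.6, 0.1])  # model confidence in the true class
    for gamma in (0.0, 1.0, 2.0):
        print(gamma, ((1 - pt) ** gamma * -torch.log(pt)).tolist())
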
diff --git a/testing/test_full.py b/testing/test_full.py
index c50700f93..73ad8af65 100644
--- a/testing/test_full.py
+++ b/testing/test_full.py
@@ -1213,30 +1213,45 @@ def test_train_metrics_regression_rad_2d(device):
 
 def test_train_losses_segmentation_rad_2d(device):
     print("23: Starting 2D Rad segmentation tests for losses")
-    # read and parse csv
-    parameters = parseConfig(
-        testingDir + "/config_segmentation.yaml", version_check_flag=False
-    )
-    training_data, parameters["headers"] = parseTrainingCSV(
-        inputDir + "/train_2d_rad_segmentation.csv"
-    )
-    parameters["modality"] = "rad"
-    parameters["patch_size"] = patch_size["2D"]
-    parameters["model"]["dimension"] = 2
-    parameters["model"]["class_list"] = [0, 255]
-    # disabling amp because some losses do not support Half, yet
-    parameters["model"]["amp"] = False
-    parameters["model"]["num_channels"] = 3
-    parameters["model"]["architecture"] = "resunet"
-    parameters["metrics"] = ["dice"]
-    parameters["model"]["onnx_export"] = False
-    parameters["model"]["print_summary"] = False
-    parameters = populate_header_in_parameters(parameters, parameters["headers"])
-    # loop through selected models and train for single epoch
-    for loss_type in ["dc", "dc_log", "dcce", "dcce_logits", "tversky"]:
+    # helper function to read and parse yaml and return parameters
+    def get_parameters_after_alteration(loss_type: str) -> tuple:
+        parameters = parseConfig(
+            testingDir + "/config_segmentation.yaml", version_check_flag=False
+        )
         parameters["loss_function"] = loss_type
+        file_config_temp = get_temp_config_path()
+        with open(file_config_temp, "w") as file:
+            yaml.dump(parameters, file)
+        # re-parse the temporary config so parseConfig expands the loss options
+        parameters = parseConfig(file_config_temp, version_check_flag=True)
         parameters["nested_training"]["testing"] = -5
         parameters["nested_training"]["validation"] = -5
+        training_data, parameters["headers"] = parseTrainingCSV(
+            inputDir + "/train_2d_rad_segmentation.csv"
+        )
+        parameters["modality"] = "rad"
+        parameters["patch_size"] = patch_size["2D"]
+        parameters["model"]["dimension"] = 2
+        parameters["model"]["class_list"] = [0, 255]
+        # disabling amp because some losses do not support Half, yet
+        parameters["model"]["amp"] = False
+        parameters["model"]["num_channels"] = 3
+        parameters["model"]["architecture"] = "resunet"
+        parameters["metrics"] = ["dice"]
+        parameters["model"]["onnx_export"] = False
+        parameters["model"]["print_summary"] = False
+        parameters = populate_header_in_parameters(parameters, parameters["headers"])
+        return parameters, training_data
+
+    # loop through selected losses and train for a single epoch
+    for loss_type in [
+        "dc",
+        "dc_log",
+        "dcce",
+        "dcce_logits",
+        "tversky",
+        "focal",
+        "dc_focal",
+    ]:
+        parameters, training_data = get_parameters_after_alteration(loss_type)
         sanitize_outputDir()
         TrainingManager(
             dataframe=training_data,
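
The write-then-reparse round trip in the helper is what exercises the new `focal` expansion in `parseConfig`. A standalone sketch of the same idea (hypothetical minimal config; a real one needs every key `parseConfig` requires, such as the version block):

    import tempfile
    import yaml

    params = {"loss_function": "focal"}  # plus the other mandatory config keys
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        yaml.dump(params, f)
        temp_path = f.name
    # parseConfig(temp_path, version_check_flag=True) should then expand
    # loss_function to {"focal": {"gamma": 2.0, "size_average": True}}
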