Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added focal loss #696

Merged
merged 18 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 12 additions & 16 deletions GANDLF/compute/loss_and_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,22 @@ def get_metric_output(metric_function, predicted, ground_truth, params):

def get_loss_and_metrics(image, ground_truth, predicted, params):
"""
image: torch.Tensor
The input image stack according to requirements
ground_truth : torch.Tensor
The input ground truth for the corresponding image label
predicted : torch.Tensor
The input predicted label for the corresponding image label
params : dict
The parameters passed by the user yaml

Returns
-------
loss : torch.Tensor
The computed loss from the label and the output
metric_output : torch.Tensor
The computed metric from the label and the output
This function computes the loss and metrics for a given image, ground truth and predicted output.

Args:
image (torch.Tensor): The input image stack according to requirements.
ground_truth (torch.Tensor): The input ground truth for the corresponding image label.
predicted (torch.Tensor): The input predicted label for the corresponding image label.
params (dict): The parameters passed by the user yaml.

Returns:
torch.Tensor: The computed loss from the label and the prediction.
dict: The computed metric from the label and the prediction.
"""
# this is currently only happening for mse_torch
if isinstance(params["loss_function"], dict):
# check for mse_torch
loss_function = global_losses_dict["mse"]
loss_function = global_losses_dict[list(params["loss_function"].keys())[0]]
else:
loss_str_lower = params["loss_function"].lower()
if loss_str_lower in global_losses_dict:
Expand Down
12 changes: 10 additions & 2 deletions GANDLF/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""
All the losses are to be called from here
"""
from .segmentation import MCD_loss, MCD_log_loss, MCT_loss, KullbackLeiblerDivergence
from .segmentation import (
MCD_loss,
MCD_log_loss,
MCT_loss,
KullbackLeiblerDivergence,
FocalLoss,
)
from .regression import CE, CEL, MSE_loss, L1_loss
from .hybrid import DCCE, DCCE_Logits
from .hybrid import DCCE, DCCE_Logits, DC_Focal


# global defines for the losses
Expand All @@ -20,4 +26,6 @@
"tversky": MCT_loss,
"kld": KullbackLeiblerDivergence,
"l1": L1_loss,
"focal": FocalLoss,
"dc_focal": DC_Focal,
}
63 changes: 35 additions & 28 deletions GANDLF/losses/hybrid.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
from .segmentation import MCD_loss
import torch

from .segmentation import MCD_loss, FocalLoss
from .regression import CCE_Generic, CE, CE_Logits


def DCCE(predicted_mask, ground_truth, params):
def DCCE(predicted_mask, ground_truth, params) -> torch.Tensor:
"""
Calculates the Dice-Cross-Entropy loss.

Parameters
----------
predicted_mask : torch.Tensor
Predicted mask
ground_truth : torch.Tensor
Ground truth mask
params : dict
Dictionary of parameters

Returns
-------
torch.Tensor
Calculated loss
Args:
predicted_mask (torch.Tensor): The predicted mask.
ground_truth (torch.Tensor): The ground truth mask.
params (dict): The parameters.

Returns:
torch.Tensor: The calculated loss.
"""
dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic(
predicted_mask, ground_truth, params, CE
Expand All @@ -30,21 +26,32 @@ def DCCE_Logits(predicted_mask, ground_truth, params):
"""
Calculates the Dice-Cross-Entropy loss using logits.

Parameters
----------
predicted_mask : torch.Tensor
Predicted mask logits
ground_truth : torch.Tensor
Ground truth mask
params : dict
Dictionary of parameters

Returns
-------
torch.Tensor
Calculated loss
Args:
predicted_mask (torch.Tensor): The predicted mask logits.
ground_truth (torch.Tensor): The ground truth mask.
params (dict): The parameters.

Returns:
torch.Tensor: The calculated loss.
"""
dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic(
predicted_mask, ground_truth, params, CE_Logits
)
return dcce_loss


def DC_Focal(predicted_mask, ground_truth, params):
    """
    Calculates the Dice-Focal loss, i.e. the sum of the Dice loss and the Focal loss.

    Args:
        predicted_mask (torch.Tensor): The predicted mask.
        ground_truth (torch.Tensor): The ground truth mask.
        params (dict): The parameters, forwarded unchanged to both loss terms.

    Returns:
        torch.Tensor: The calculated loss.
    """
    # evaluate the two terms separately (dice first, then focal) and add them
    dice_term = MCD_loss(predicted_mask, ground_truth, params)
    focal_term = FocalLoss(predicted_mask, ground_truth, params)
    return dice_term + focal_term
80 changes: 64 additions & 16 deletions GANDLF/losses/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,16 @@


# Dice scores and dice losses
def dice(predicted, target):
def dice(predicted, target) -> torch.Tensor:
"""
This function computes a dice score between two tensors
This function computes a dice score between two tensors.

Parameters
----------
predicted : Tensor
predicted value by the network
target : Tensor
Required target label to match the predicted with

Returns
-------
Tensor
Computed Dice Score
Args:
predicted (torch.Tensor): Predicted value by the network.
target (torch.Tensor): Required target label to match the predicted with.

Returns:
torch.Tensor: The computed dice score.
"""
predicted_flat = predicted.flatten()
label_flat = target.flatten()
Expand All @@ -36,7 +30,7 @@ def MCD(predicted, target, num_class, weights=None, ignore_class=None, loss_type
This function computes the mean class dice score between two tensors

Args:
predicted (torch.Tensor): predicted generally by the network
predicted (torch.Tensor): Predicted generally by the network
target (torch.Tensor): Required target label to match the predicted with
num_class (int): Number of classes (including the background class)
weights (list, optional): Dice weights for each class (excluding the background class), defaults to None
Expand Down Expand Up @@ -105,7 +99,7 @@ def tversky_loss(predicted, target, alpha=0.5, beta=0.5):
This function calculates the Tversky loss between two tensors.

Args:
predicted (torch.Tensor): predicted generally by the network
predicted (torch.Tensor): Predicted generally by the network
target (torch.Tensor): Required target label to match the predicted with
alpha (float, optional): Weight of false positives. Defaults to 0.5.
beta (float, optional): Weight of false negatives. Defaults to 0.5.
Expand Down Expand Up @@ -138,7 +132,7 @@ def MCT_loss(predicted, target, params=None):
This function calculates the Multi-Class Tversky loss between two tensors.

Args:
predicted (torch.Tensor): predicted generally by the network
predicted (torch.Tensor): Predicted generally by the network
target (torch.Tensor): Required target label to match the predicted with
params (dict, optional): Additional parameters for computing loss function, including weights for each class

Expand Down Expand Up @@ -175,3 +169,57 @@ def KullbackLeiblerDivergence(mu, logvar, params=None):
"""
loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
return loss.mean()


def FocalLoss(predicted, target, params=None):
    """
    This function calculates the Focal loss between two tensors.

    The loss is accumulated over the entries of dimension 1 of ``predicted``: each slice
    ``predicted[:, i, ...]`` / ``target[:, i, ...]`` is scored with an unreduced
    ``torch.nn.CrossEntropyLoss``, modulated by ``(1 - pt) ** gamma``, reduced, optionally
    scaled by ``params["weights"][i]`` and added to the running total.

    Args:
        predicted (torch.Tensor): Predicted generally by the network
        target (torch.Tensor): Required target label to match the predicted with
        params (dict, optional): Additional parameters for computing loss function. Recognised keys:
            - "loss_function": when it is a dict, "gamma" and "size_average" are read from its
              nested "focal" dict (the layout used by the config parser); if there is no nested
              "focal" dict, they are read from the top level of "loss_function" instead.
            - "weights": per-class multipliers applied to each class loss.
            Defaults to None, in which case gamma=2.0, size_average=True and no weights are used.

    Returns:
        torch.Tensor: Computed Focal Loss
    """
    gamma = 2.0
    size_average = True
    weights = None
    if params is not None:
        weights = params.get("weights")
        loss_config = params.get("loss_function")
        if isinstance(loss_config, dict):
            # the config stores the options under the "focal" key; fall back to the
            # flat layout so that callers passing {"gamma": ...} directly keep working
            focal_config = loss_config.get("focal")
            if not isinstance(focal_config, dict):
                focal_config = loss_config
            gamma = focal_config.get("gamma", 2.0)
            size_average = focal_config.get("size_average", True)

    # built once and reused for every class; reduction="none" replaces the deprecated reduce=False
    ce_loss = torch.nn.CrossEntropyLoss(reduction="none")

    def _focal_loss(preds, target, gamma, size_average=True):
        """
        Internal helper function to calculate focal loss for a single class.

        Args:
            preds (torch.Tensor): predicted generally by the network
            target (torch.Tensor): Required target label to match the predicted with
            gamma (float): The gamma value for focal loss
            size_average (bool, optional): Whether to take the mean (True) or the sum (False) of the unreduced loss. Defaults to True.

        Returns:
            torch.Tensor: Computed focal loss for a single class.
        """
        # NOTE(review): the class axis has already been sliced away, so CrossEntropyLoss
        # treats dimension 1 of the slice as its class dimension -- confirm this is intended
        logpt = ce_loss(preds, target)
        pt = torch.exp(-logpt)
        loss = ((1 - pt) ** gamma) * logpt
        return loss.mean() if size_average else loss.sum()

    acc_focal_loss = 0
    num_classes = predicted.shape[1]

    for i in range(num_classes):
        curr_loss = _focal_loss(
            predicted[:, i, ...], target[:, i, ...], gamma, size_average
        )
        if weights is not None:
            curr_loss = curr_loss * weights[i]
        acc_focal_loss += curr_loss

    return acc_focal_loss
6 changes: 5 additions & 1 deletion GANDLF/parseConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
"optimizer": "adam", # the optimizer
"patch_sampler": "uniform", # type of sampling strategy
"scheduler": "triangle_modified", # the default scheduler
"loss_function": "dc", # default loss
"clip_mode": None, # default clip mode
}

Expand Down Expand Up @@ -186,6 +185,11 @@ def parseConfig(config_file_path, version_check_flag=True):
params["loss_function"] = {}
params["loss_function"]["mse"] = {}
params["loss_function"]["mse"]["reduction"] = "mean"
elif params["loss_function"] == "focal":
params["loss_function"] = {}
params["loss_function"]["focal"] = {}
params["loss_function"]["focal"]["gamma"] = 2.0
params["loss_function"]["focal"]["size_average"] = True

assert "metrics" in params, "'metrics' needs to be defined in the config file"
if "metrics" in params:
Expand Down
2 changes: 1 addition & 1 deletion docs/customize.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ This file contains mid-level information regarding various parameters that can b
- Defined in the `loss_function` parameter of the model configuration.
- By passing `weighted_loss: True`, the loss function will be weighted by the inverse of the class frequency.
- This parameter controls the function which the model is trained. All options can be found [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/losses/__init__.py). Some examples are:
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`)
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`)
- Classification/regression: mean squared error (`mse`)
- And many more.

Expand Down
7 changes: 7 additions & 0 deletions samples/config_all_options.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ scheduler:
# Set which loss function you want to use - options : 'dc' - for dice only, 'dcce' - for sum of dice and CE and you can guess the next (only lower-case please)
# options: dc (dice only), dc_log (-log of dice), ce (), dcce (sum of dice and ce), mse () ...
# mse is the MSE defined by torch and can define a variable 'reduction'; see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
# focal is the focal loss and can define 2 variables: gamma and size_average
# use mse_torch for regression/classification problems and dice for segmentation
loss_function: dc
# this parameter weights the loss to handle imbalanced losses better
Expand All @@ -122,6 +123,12 @@ weighted_loss: True
# 'reduction': 'mean' # see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss for all options
# }
# }
#loss_function:
# {
# 'focal':{
# 'gamma': 1.0
# }
# }
# Which optimizer do you want to use - sgd, asgd, adam, adamw, adamax, sparseadam, rprop, adadelta, adagrad, rmsprop,
# each has their own options and functionalities, which are initialized with defaults, see GANDLF.optimizers.wrap_torch for details
optimizer: adam
Expand Down
57 changes: 36 additions & 21 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,30 +1213,45 @@ def test_train_metrics_regression_rad_2d(device):

def test_train_losses_segmentation_rad_2d(device):
print("23: Starting 2D Rad segmentation tests for losses")
# read and parse csv
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
parameters["model"]["dimension"] = 2
parameters["model"]["class_list"] = [0, 255]
# disabling amp because some losses do not support Half, yet
parameters["model"]["amp"] = False
parameters["model"]["num_channels"] = 3
parameters["model"]["architecture"] = "resunet"
parameters["metrics"] = ["dice"]
parameters["model"]["onnx_export"] = False
parameters["model"]["print_summary"] = False
parameters = populate_header_in_parameters(parameters, parameters["headers"])
# loop through selected models and train for single epoch
for loss_type in ["dc", "dc_log", "dcce", "dcce_logits", "tversky"]:
# helper function to read and parse yaml and return parameters
def get_parameters_after_alteration(loss_type: str) -> dict:
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
parameters["loss_function"] = loss_type
file_config_temp = get_temp_config_path()
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
# read and parse csv
parameters = parseConfig(file_config_temp, version_check_flag=True)
parameters["nested_training"]["testing"] = -5
parameters["nested_training"]["validation"] = -5
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
parameters["model"]["dimension"] = 2
parameters["model"]["class_list"] = [0, 255]
# disabling amp because some losses do not support Half, yet
parameters["model"]["amp"] = False
parameters["model"]["num_channels"] = 3
parameters["model"]["architecture"] = "resunet"
parameters["metrics"] = ["dice"]
parameters["model"]["onnx_export"] = False
parameters["model"]["print_summary"] = False
parameters = populate_header_in_parameters(parameters, parameters["headers"])
return parameters, training_data
# loop through selected models and train for single epoch
for loss_type in [
"dc",
"dc_log",
"dcce",
"dcce_logits",
"tversky",
"focal",
"dc_focal"]:
parameters, training_data = get_parameters_after_alteration(loss_type)
sanitize_outputDir()
TrainingManager(
dataframe=training_data,
Expand Down
Loading