
Commit

Merge branch 'master' into patch-2
sarthakpati authored Jul 26, 2023
2 parents 0e2d72e + d1fc9e2 commit 5dee4ec
Showing 8 changed files with 170 additions and 85 deletions.
28 changes: 12 additions & 16 deletions GANDLF/compute/loss_and_metric.py
@@ -25,26 +25,22 @@ def get_metric_output(metric_function, predicted, ground_truth, params):

def get_loss_and_metrics(image, ground_truth, predicted, params):
"""
image: torch.Tensor
The input image stack according to requirements
ground_truth : torch.Tensor
The input ground truth for the corresponding image label
predicted : torch.Tensor
The input predicted label for the corresponding image label
params : dict
The parameters passed by the user yaml
Returns
-------
loss : torch.Tensor
The computed loss from the label and the output
metric_output : torch.Tensor
The computed metric from the label and the output
This function computes the loss and metrics for a given image, ground truth and predicted output.
Args:
image (torch.Tensor): The input image stack according to requirements.
ground_truth (torch.Tensor): The input ground truth for the corresponding image label.
predicted (torch.Tensor): The input predicted label for the corresponding image label.
params (dict): The parameters passed by the user yaml.
Returns:
torch.Tensor: The computed loss from the label and the prediction.
dict: The computed metric from the label and the prediction.
"""
# this is currently only happening for mse_torch
if isinstance(params["loss_function"], dict):
# check for mse_torch
loss_function = global_losses_dict["mse"]
loss_function = global_losses_dict[list(params["loss_function"].keys())[0]]
else:
loss_str_lower = params["loss_function"].lower()
if loss_str_lower in global_losses_dict:
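Note: an illustrative sketch (not code from this commit) of what the dict handling above does: a dict-valued loss_function now selects the loss registered under its first key, instead of always falling back to mse.

# Illustrative sketch of the dict-handling change above; `registry`
# stands in for global_losses_dict, which maps lowercase names to
# loss callables (see GANDLF/losses/__init__.py).
def pick_loss(loss_config, registry):
    if isinstance(loss_config, dict):
        # a dict-valued config selects the loss named by its first key
        return registry[list(loss_config.keys())[0]]
    return registry[loss_config.lower()]

registry = {"mse": "MSE_loss", "focal": "FocalLoss"}  # stand-ins for callables
assert pick_loss({"focal": {"gamma": 2.0}}, registry) == "FocalLoss"
assert pick_loss("MSE", registry) == "MSE_loss"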
12 changes: 10 additions & 2 deletions GANDLF/losses/__init__.py
@@ -1,9 +1,15 @@
"""
All the losses are to be called from here
"""
from .segmentation import MCD_loss, MCD_log_loss, MCT_loss, KullbackLeiblerDivergence
from .segmentation import (
    MCD_loss,
    MCD_log_loss,
    MCT_loss,
    KullbackLeiblerDivergence,
    FocalLoss,
)
from .regression import CE, CEL, MSE_loss, L1_loss
from .hybrid import DCCE, DCCE_Logits
from .hybrid import DCCE, DCCE_Logits, DC_Focal


# global defines for the losses
@@ -20,4 +26,6 @@
"tversky": MCT_loss,
"kld": KullbackLeiblerDivergence,
"l1": L1_loss,
"focal": FocalLoss,
"dc_focal": DC_Focal,
}
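Note: since losses are dispatched through this dict, a new loss can in principle be registered with the same (predicted, target, params) signature; the name my_loss and the import path below are illustrative assumptions, not part of this commit.

import torch

from GANDLF.losses import global_losses_dict  # assumed import path

def my_loss(predicted, target, params=None):
    # toy L2 loss, purely illustrative
    return torch.mean((predicted - target) ** 2)

# now selectable in the YAML as `loss_function: my_loss`
global_losses_dict["my_loss"] = my_loss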
63 changes: 35 additions & 28 deletions GANDLF/losses/hybrid.py
@@ -1,24 +1,20 @@
from .segmentation import MCD_loss
import torch

from .segmentation import MCD_loss, FocalLoss
from .regression import CCE_Generic, CE, CE_Logits


def DCCE(predicted_mask, ground_truth, params):
def DCCE(predicted_mask, ground_truth, params) -> torch.Tensor:
"""
Calculates the Dice-Cross-Entropy loss.
Parameters
----------
predicted_mask : torch.Tensor
Predicted mask
ground_truth : torch.Tensor
Ground truth mask
params : dict
Dictionary of parameters
Returns
-------
torch.Tensor
Calculated loss
Args:
predicted_mask (torch.Tensor): The predicted mask.
ground_truth (torch.Tensor): The ground truth mask.
params (dict): The parameters.
Returns:
torch.Tensor: The calculated loss.
"""
dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic(
predicted_mask, ground_truth, params, CE
@@ -30,21 +26,32 @@ def DCCE_Logits(predicted_mask, ground_truth, params):
"""
Calculates the Dice-Cross-Entropy loss using logits.
Parameters
----------
predicted_mask : torch.Tensor
Predicted mask logits
ground_truth : torch.Tensor
Ground truth mask
params : dict
Dictionary of parameters
Returns
-------
torch.Tensor
Calculated loss
Args:
predicted_mask (torch.Tensor): The predicted mask.
ground_truth (torch.Tensor): The ground truth mask.
params (dict): The parameters.
Returns:
torch.Tensor: The calculated loss.
"""
dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic(
predicted_mask, ground_truth, params, CE_Logits
)
return dcce_loss


def DC_Focal(predicted_mask, ground_truth, params):
"""
Calculates the Dice-Focal loss.
Args:
predicted_mask (torch.Tensor): The predicted mask.
ground_truth (torch.Tensor): The ground truth mask.
params (dict): The parameters.
Returns:
torch.Tensor: The calculated loss.
"""
return MCD_loss(predicted_mask, ground_truth, params) + FocalLoss(
predicted_mask, ground_truth, params
)
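Note: to make the Dice-Focal composition concrete, here is a minimal, self-contained sketch (illustrative only: it re-implements a soft Dice and a binary focal term locally rather than calling GaNDLF's MCD_loss and FocalLoss, whose params plumbing is not shown in this diff).

import torch

def soft_dice_loss(pred, target, eps=1e-7):
    # soft Dice over all elements; pred and target hold values in [0, 1]
    intersection = (pred * target).sum()
    return 1.0 - (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

def binary_focal_loss(pred, target, gamma=2.0, eps=1e-7):
    # focal term: (1 - p_t)^gamma * -log(p_t), averaged over elements
    p_t = torch.where(target > 0.5, pred, 1.0 - pred).clamp(min=eps)
    return (((1.0 - p_t) ** gamma) * -torch.log(p_t)).mean()

def dice_focal_loss(pred, target, gamma=2.0):
    # sum of the two terms, mirroring MCD_loss(...) + FocalLoss(...) above
    return soft_dice_loss(pred, target) + binary_focal_loss(pred, target, gamma)

target = (torch.rand(2, 1, 8, 8) > 0.5).float()
print(dice_focal_loss(target.clone(), target))          # ~0 for a perfect mask
print(dice_focal_loss(torch.rand(2, 1, 8, 8), target))  # noticeably larger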
80 changes: 64 additions & 16 deletions GANDLF/losses/segmentation.py
@@ -3,22 +3,16 @@


# Dice scores and dice losses
def dice(predicted, target):
def dice(predicted, target) -> torch.Tensor:
"""
This function computes a dice score between two tensors
This function computes a dice score between two tensors.
Parameters
----------
predicted : Tensor
predicted value by the network
target : Tensor
Required target label to match the predicted with
Returns
-------
Tensor
Computed Dice Score
Args:
predicted (_type_): Predicted value by the network.
target (_type_): Required target label to match the predicted with
Returns:
torch.Tensor: The computed dice score.
"""
predicted_flat = predicted.flatten()
label_flat = target.flatten()
@@ -36,7 +30,7 @@ def MCD(predicted, target, num_class, weights=None, ignore_class=None, loss_type
    This function computes the mean class dice score between two tensors
    Args:
        predicted (torch.Tensor): predicted generally by the network
        predicted (torch.Tensor): Predicted generally by the network
        target (torch.Tensor): Required target label to match the predicted with
        num_class (int): Number of classes (including the background class)
        weights (list, optional): Dice weights for each class (excluding the background class), defaults to None
@@ -105,7 +99,7 @@ def tversky_loss(predicted, target, alpha=0.5, beta=0.5):
    This function calculates the Tversky loss between two tensors.
    Args:
        predicted (torch.Tensor): predicted generally by the network
        predicted (torch.Tensor): Predicted generally by the network
        target (torch.Tensor): Required target label to match the predicted with
        alpha (float, optional): Weight of false positives. Defaults to 0.5.
        beta (float, optional): Weight of false negatives. Defaults to 0.5.
@@ -138,7 +132,7 @@ def MCT_loss(predicted, target, params=None):
    This function calculates the Multi-Class Tversky loss between two tensors.
    Args:
        predicted (torch.Tensor): predicted generally by the network
        predicted (torch.Tensor): Predicted generally by the network
        target (torch.Tensor): Required target label to match the predicted with
        params (dict, optional): Additional parameters for computing loss function, including weights for each class
@@ -175,3 +169,57 @@ def KullbackLeiblerDivergence(mu, logvar, params=None):
"""
loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
return loss.mean()
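Note (for reference, not part of this commit): the expression above is the standard closed-form KL divergence between a diagonal Gaussian posterior and the standard normal prior, with logvar holding the log-variance:

$$
\mathrm{KL}\!\left(\mathcal{N}(\mu,\sigma^{2}) \,\Vert\, \mathcal{N}(0, I)\right)
  = -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right),
\qquad \texttt{logvar}_{j} = \log\sigma_{j}^{2}
$$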


def FocalLoss(predicted, target, params=None):
    """
    This function calculates the Focal loss between two tensors.
    Args:
        predicted (torch.Tensor): Predicted generally by the network
        target (torch.Tensor): Required target label to match the predicted with
        params (dict, optional): Additional parameters for computing loss function, including weights for each class
    Returns:
        torch.Tensor: Computed Focal Loss
    """
    gamma = 2.0
    size_average = True
    if isinstance(params["loss_function"], dict):
        gamma = params["loss_function"].get("gamma", 2.0)
        size_average = params["loss_function"].get("size_average", True)

    def _focal_loss(preds, target, gamma, size_average=True):
        """
        Internal helper function to calculate focal loss for a single class.
        Args:
            preds (torch.Tensor): predicted generally by the network
            target (torch.Tensor): Required target label to match the predicted with
            gamma (float): The gamma value for focal loss
            size_average (bool, optional): Whether to average the loss across the batch. Defaults to True.
        Returns:
            torch.Tensor: Computed focal loss for a single class.
        """
        ce_loss = torch.nn.CrossEntropyLoss(reduction="none")
        logpt = ce_loss(preds, target)
        pt = torch.exp(-logpt)
        loss = ((1 - pt) ** gamma) * logpt
        return_loss = loss.sum()
        if size_average:
            return_loss = loss.mean()
        return return_loss

    acc_focal_loss = 0
    num_classes = predicted.shape[1]

    for i in range(num_classes):
        curr_loss = _focal_loss(
            predicted[:, i, ...], target[:, i, ...], gamma, size_average
        )
        if params is not None and params.get("weights") is not None:
            curr_loss = curr_loss * params["weights"][i]
        acc_focal_loss += curr_loss

    return acc_focal_loss
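Note: a quick, self-contained sanity check of the focal formulation (illustrative only): with gamma = 0 the focal term reduces exactly to cross-entropy, and larger gamma down-weights confidently classified samples.

import torch

# Sanity check: focal loss FL = (1 - p_t)^gamma * CE reduces to plain
# cross-entropy at gamma = 0, and shrinks the contribution of easy
# (high-confidence) samples for gamma > 0.
def focal_from_ce(logits, target, gamma):
    ce = torch.nn.functional.cross_entropy(logits, target, reduction="none")
    p_t = torch.exp(-ce)  # probability the model assigns to the true class
    return (((1.0 - p_t) ** gamma) * ce).mean()

logits = torch.randn(4, 3)           # 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])  # integer class labels

ce_mean = torch.nn.functional.cross_entropy(logits, target)
assert torch.allclose(focal_from_ce(logits, target, gamma=0.0), ce_mean)
print(focal_from_ce(logits, target, gamma=2.0))  # always <= ce_mean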
6 changes: 5 additions & 1 deletion GANDLF/parseConfig.py
@@ -41,7 +41,6 @@
"optimizer": "adam", # the optimizer
"patch_sampler": "uniform", # type of sampling strategy
"scheduler": "triangle_modified", # the default scheduler
"loss_function": "dc", # default loss
"clip_mode": None, # default clip mode
}

@@ -186,6 +185,11 @@ def parseConfig(config_file_path, version_check_flag=True):
params["loss_function"] = {}
params["loss_function"]["mse"] = {}
params["loss_function"]["mse"]["reduction"] = "mean"
elif params["loss_function"] == "focal":
params["loss_function"] = {}
params["loss_function"]["focal"] = {}
params["loss_function"]["focal"]["gamma"] = 2.0
params["loss_function"]["focal"]["size_average"] = True

assert "metrics" in params, "'metrics' needs to be defined in the config file"
if "metrics" in params:
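Note: an illustrative helper (not a function in parseConfig) showing what the normalization above accomplishes: a bare "focal" string in the YAML expands into the dict form carrying the defaults that FocalLoss later reads.

# Illustrative only: mirrors the normalization above for the "focal" case.
def normalize_focal_option(loss_function):
    # expand the bare string "focal" into its dict form with defaults
    if loss_function == "focal":
        return {"focal": {"gamma": 2.0, "size_average": True}}
    return loss_function  # dict-valued configs pass through unchanged

assert normalize_focal_option("focal") == {
    "focal": {"gamma": 2.0, "size_average": True}
}
assert normalize_focal_option({"focal": {"gamma": 1.0}}) == {"focal": {"gamma": 1.0}}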
2 changes: 1 addition & 1 deletion docs/customize.md
@@ -39,7 +39,7 @@ This file contains mid-level information regarding various parameters that can b
- Defined in the `loss_function` parameter of the model configuration.
- By passing `weighted_loss: True`, the loss function will be weighted by the inverse of the class frequency.
- This parameter controls the loss function with which the model is trained. All options can be found [here](https://github.com/mlcommons/GaNDLF/blob/master/GANDLF/losses/__init__.py). Some examples are:
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`)
- Segmentation: dice (`dice` or `dc`), dice and cross entropy (`dcce`), focal loss (`focal`), dice and focal (`dc_focal`)
- Classification/regression: mean squared error (`mse`)
- And many more.

7 changes: 7 additions & 0 deletions samples/config_all_options.yaml
@@ -112,6 +112,7 @@ scheduler:
# Set which loss function you want to use - options: 'dc' - for dice only, 'dcce' - for sum of dice and CE and you can guess the next (only lower-case please)
# options: dc (dice only), dc_log (-log of dice), ce (cross entropy), dcce (sum of dice and ce), mse (mean squared error) ...
# mse is the MSE defined by torch and can define a variable 'reduction'; see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
# focal is the focal loss and can define 2 variables: gamma and size_average
# use mse_torch for regression/classification problems and dice for segmentation
loss_function: dc
# this parameter weights the loss to handle imbalanced losses better
@@ -122,6 +123,12 @@ weighted_loss: True
# 'reduction': 'mean' # see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss for all options
# }
# }
# loss_function:
# {
# 'focal':{
# 'gamma': 1.0
# }
# }
# Which optimizer do you want to use - sgd, asgd, adam, adamw, adamax, sparseadam, rprop, adadelta, adagrad, rmsprop,
# each has their own options and functionalities, which are initialized with defaults, see GANDLF.optimizers.wrap_torch for details
optimizer: adam
57 changes: 36 additions & 21 deletions testing/test_full.py
@@ -1213,30 +1213,45 @@ def test_train_metrics_regression_rad_2d(device):

def test_train_losses_segmentation_rad_2d(device):
print("23: Starting 2D Rad segmentation tests for losses")
# read and parse csv
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
parameters["model"]["dimension"] = 2
parameters["model"]["class_list"] = [0, 255]
# disabling amp because some losses do not support Half, yet
parameters["model"]["amp"] = False
parameters["model"]["num_channels"] = 3
parameters["model"]["architecture"] = "resunet"
parameters["metrics"] = ["dice"]
parameters["model"]["onnx_export"] = False
parameters["model"]["print_summary"] = False
parameters = populate_header_in_parameters(parameters, parameters["headers"])
# loop through selected models and train for single epoch
for loss_type in ["dc", "dc_log", "dcce", "dcce_logits", "tversky"]:
# healper function to read and parse yaml and return parameters
def get_parameters_after_alteration(loss_type: str) -> dict:
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
parameters["loss_function"] = loss_type
file_config_temp = get_temp_config_path()
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
# read and parse csv
parameters = parseConfig(file_config_temp, version_check_flag=True)
parameters["nested_training"]["testing"] = -5
parameters["nested_training"]["validation"] = -5
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
parameters["model"]["dimension"] = 2
parameters["model"]["class_list"] = [0, 255]
# disabling amp because some losses do not support Half, yet
parameters["model"]["amp"] = False
parameters["model"]["num_channels"] = 3
parameters["model"]["architecture"] = "resunet"
parameters["metrics"] = ["dice"]
parameters["model"]["onnx_export"] = False
parameters["model"]["print_summary"] = False
parameters = populate_header_in_parameters(parameters, parameters["headers"])
return parameters, training_data
# loop through selected models and train for single epoch
for loss_type in [
"dc",
"dc_log",
"dcce",
"dcce_logits",
"tversky",
"focal",
"dc_focal"]:
parameters, training_data = get_parameters_after_alteration(loss_type)
sanitize_outputDir()
TrainingManager(
dataframe=training_data,
