diff --git a/GANDLF/anonymize/__init__.py b/GANDLF/anonymize/__init__.py index 167ac5692..e0c65ab90 100644 --- a/GANDLF/anonymize/__init__.py +++ b/GANDLF/anonymize/__init__.py @@ -8,7 +8,7 @@ def run_anonymizer( input_path: str, output_path: str, parameters: Union[str, list, int], modality: str -): +) -> None: """ This function performs anonymization of a single image or a collection of images. @@ -17,9 +17,6 @@ def run_anonymizer( output_path (str): The output file or folder. parameters (Union[str, list, int]): The parameters for anonymization; for DICOM scans, the only optional argument is "delete_private_tags", which defaults to True. output_path (str): The modality type to process. - - Returns: - torch.Tensor: The output image after morphological operations. """ if parameters is None: parameters = {} diff --git a/GANDLF/anonymize/convert_to_nifti.py b/GANDLF/anonymize/convert_to_nifti.py index eee4b6e5e..75e05ddfd 100644 --- a/GANDLF/anonymize/convert_to_nifti.py +++ b/GANDLF/anonymize/convert_to_nifti.py @@ -1,7 +1,7 @@ import SimpleITK as sitk -def convert_to_nifti(input_dicom_directory, output_file): +def convert_to_nifti(input_dicom_directory: str, output_file: str) -> None: """ This function performs NIfTI conversion of a DICOM image series. diff --git a/GANDLF/cli/config_generator.py b/GANDLF/cli/config_generator.py index 36bb73e35..f4033f7e3 100644 --- a/GANDLF/cli/config_generator.py +++ b/GANDLF/cli/config_generator.py @@ -1,11 +1,15 @@ import yaml +from typing import List, Optional, Union from pathlib import Path from copy import deepcopy def generate_new_configs_from_key_and_value( - base_config, key, value, upper_level_key=None -): + base_config: dict, + key: str, + value: Union[str, list, int], + upper_level_key: Optional[str] = None, +) -> List[dict]: """ Generate new configs based on a base config and a strategy. @@ -13,10 +17,10 @@ def generate_new_configs_from_key_and_value( base_config (dict): The base configuration to generate new configs from. key (str): The key to change in the base config. value (Union[str, list, int]): The value to change the key to. - upper_level_key (str, optional): The upper level key to change in the base config; useful for dict. Defaults to None. + upper_level_key (Optional[str]): The upper level key in the base config; useful for dict. Defaults to None. Returns: - list: A list of new configs. + List[dict]: A list of new configs. """ configs_to_return = [] if key == "patch_size": @@ -59,7 +63,7 @@ def generate_new_configs_from_key_and_value( return configs_to_return -def remove_duplicates(configs_list): +def remove_duplicates(configs_list: List[dict]) -> List[dict]: """ Remove duplicate configs from a list of configs. @@ -67,7 +71,7 @@ def remove_duplicates(configs_list): configs_list (list): A list of configs. Returns: - list: A list of configs with duplicates removed. + List[dict]: A list of configs with duplicates removed. """ configs_to_return = [] for config in configs_list: @@ -76,7 +80,9 @@ def remove_duplicates(configs_list): return configs_to_return -def config_generator(base_config_path, strategy_path, output_dir): +def config_generator( + base_config_path: str, strategy_path: str, output_dir: str +) -> None: """ Main function that runs the training and inference. diff --git a/GANDLF/cli/deploy.py b/GANDLF/cli/deploy.py index 242a2cea1..df0b2c6f6 100644 --- a/GANDLF/cli/deploy.py +++ b/GANDLF/cli/deploy.py @@ -1,10 +1,5 @@ -import os -import shutil -import yaml -import docker -import tarfile -import io -import sysconfig +import os, shutil, yaml, docker, tarfile, io, sysconfig +from typing import Optional # import copy @@ -18,30 +13,30 @@ def run_deployment( - mlcubedir, - outputdir, - target, - mlcube_type, - entrypoint_script=None, - configfile=None, - modeldir=None, - requires_gpu=None, -): + mlcubedir: str, + outputdir: str, + target: str, + mlcube_type: str, + entrypoint_script: Optional[str] = None, + configfile: Optional[str] = None, + modeldir: Optional[str] = None, + requires_gpu: Optional[bool] = None, +) -> bool: """ - Run the deployment of the model. + This function runs the deployment of the mlcube. Args: mlcubedir (str): The path to the mlcube directory. outputdir (str): The path to the output directory. - target (str): The target to deploy to. - mlcube_type (str): Either 'model' or 'metrics' - entrypoint_script (str): The path of entrypoint script. Only used for metrics and inference - configfile (str, Optional): The path to the configuration file. Required for models - modeldir (str, Optional): The path to the model directory. Required for models - requires_gpu (str, Optional): Whether the model requires GPU. Required for models + target (str): The deployment target. + mlcube_type (str): The type of mlcube. + entrypoint_script (str, optional): The path of entrypoint script; only used for metrics and inference. Defaults to None. + configfile (str, optional): The path of the configuration file; required for models. Defaults to None. + modeldir (str, optional): The path of the model directory; required for models. Defaults to None. + requires_gpu (bool, optional): Whether the model requires GPU; required for models. Defaults to None. Returns: - bool: True if the deployment was successful, False otherwise. + bool: True if the deployment is successful. """ assert ( target in deploy_targets @@ -86,23 +81,26 @@ def run_deployment( def deploy_docker_mlcube( - mlcubedir, - outputdir, - entrypoint_script=None, - config=None, - modeldir=None, - requires_gpu=None, -): + mlcubedir: str, + outputdir: str, + entrypoint_script: Optional[str] = None, + config: Optional[str] = None, + modeldir: Optional[str] = None, + requires_gpu: Optional[bool] = None, +) -> bool: """ - Deploy the docker mlcube of the model or metrics calculator. + This function deploys the mlcube as a docker container. Args: mlcubedir (str): The path to the mlcube directory. outputdir (str): The path to the output directory. - entrypoint_script (str): The path of entrypoint script. Only used for metrics and inference - config (str, Optional): The path to the configuration file. Required for models - modeldir (str, Optional): The path to the model directory. Required for models - requires_gpu (str, Optional): Whether the model requires GPU. Required for models + entrypoint_script (str, optional): The path of the entrypoint script; only used for metrics and inference. Defaults to None. + config (str, optional): The path of the configuration file; required for models. Defaults to None. + modeldir (str, optional): The path of the model directory; required for models. Defaults to None. + requires_gpu (bool, optional): Whether the model requires GPU; required for models. Defaults to None. + + Returns: + bool: True if the deployment is successful. """ mlcube_config_file = os.path.join(mlcubedir, "mlcube.yaml") assert os.path.exists(mlcubedir) and os.path.exists( @@ -238,13 +236,18 @@ def deploy_docker_mlcube( return True -def get_metrics_mlcube_config(mlcube_config_file, entrypoint_script): +def get_metrics_mlcube_config( + mlcube_config_file: str, entrypoint_script: Optional[str] = None +) -> dict: """ - This function is used to get the metrics from mlcube config file. + This function returns the mlcube config for the metrics. Args: - mlcube_config_file (str): The path of mlcube config file. - entrypoint_script (str): The path of entrypoint script. + mlcube_config_file (str): Path to mlcube config file. + entrypoint_script (str, optional): The path of entrypoint script; only used for metrics. Defaults to None. + + Returns: + dict: The mlcube config for the metrics. """ mlcube_config = None with open(mlcube_config_file, "r") as f: @@ -256,14 +259,19 @@ def get_metrics_mlcube_config(mlcube_config_file, entrypoint_script): return mlcube_config -def get_model_mlcube_config(mlcube_config_file, requires_gpu, entrypoint_script): +def get_model_mlcube_config( + mlcube_config_file: str, requires_gpu: bool, entrypoint_script: Optional[str] = None +) -> dict: """ This function returns the mlcube config for the model. Args: mlcube_config_file (str): Path to mlcube config file. requires_gpu (bool): Whether the model requires GPU. - entrypoint_script (str): The path of entrypoint script. Only used for infer task + entrypoint_script (str, optional): The path of entrypoint script; only used for models. Defaults to None. + + Returns: + dict: The mlcube config for the model. """ mlcube_config = None with open(mlcube_config_file, "r") as f: @@ -333,7 +341,7 @@ def get_model_mlcube_config(mlcube_config_file, requires_gpu, entrypoint_script) # ) -def embed_asset(asset, container, asset_name): +def embed_asset(asset: str, container: object, asset_name: str) -> None: """ This function embeds an asset into a container. diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index 24ba6c144..670541b55 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -1,5 +1,6 @@ import sys import yaml +from typing import Optional from pprint import pprint import pandas as pd from tqdm import tqdm @@ -30,7 +31,9 @@ ) -def generate_metrics_dict(input_csv: str, config: str, outputfile: str = None) -> dict: +def generate_metrics_dict( + input_csv: str, config: str, outputfile: Optional[str] = None +) -> dict: """ This function generates metrics from the input csv and the config. @@ -195,18 +198,18 @@ def __fix_2d_tensor(input_tensor): return input_tensor def __percentile_clip( - input_tensor, - reference_tensor=None, - p_min=0.5, - p_max=99.5, - strictlyPositive=True, + input_tensor: torch.Tensor, + reference_tensor: torch.Tensor = None, + p_min: Optional[float] = 0.5, + p_max: Optional[float] = 99.5, + strictlyPositive: Optional[bool] = True, ): - """Normalizes a tensor based on percentiles. Clips values below and above the percentile. + """ + Normalizes a tensor based on percentiles. Clips values below and above the percentile. Percentiles for normalization can come from another tensor. Args: - input_tensor (torch.Tensor): Tensor to be normalized based on the data from the reference_tensor. - If reference_tensor is None, the percentiles from this tensor will be used. + input_tensor (torch.Tensor): Tensor to be normalized based on the data from the reference_tensor. If reference_tensor is None, the percentiles from this tensor will be used. reference_tensor (torch.Tensor, optional): The tensor used for obtaining the percentiles. p_min (float, optional): Lower end percentile. Defaults to 0.5. p_max (float, optional): Upper end percentile. Defaults to 99.5. @@ -276,24 +279,24 @@ def __percentile_clip( strictlyPositive=True, ) - overall_stats_dict[current_subject_id][ - "ssim" - ] = structural_similarity_index(gt_image_infill, output_infill, mask).item() + overall_stats_dict[current_subject_id]["ssim"] = ( + structural_similarity_index(output_infill, gt_image_infill, mask).item() + ) # ncc metrics compute_ncc = parameters.get("compute_ncc", True) if compute_ncc: overall_stats_dict[current_subject_id]["ncc_mean"] = ncc_mean( - gt_image_infill, output_infill + output_infill, gt_image_infill ) overall_stats_dict[current_subject_id]["ncc_std"] = ncc_std( - gt_image_infill, output_infill + output_infill, gt_image_infill ) overall_stats_dict[current_subject_id]["ncc_max"] = ncc_max( - gt_image_infill, output_infill + output_infill, gt_image_infill ) overall_stats_dict[current_subject_id]["ncc_min"] = ncc_min( - gt_image_infill, output_infill + output_infill, gt_image_infill ) # only voxels that are to be inferred (-> flat array) @@ -302,47 +305,47 @@ def __percentile_clip( output_infill = output_infill[mask] overall_stats_dict[current_subject_id]["mse"] = mean_squared_error( - gt_image_infill, output_infill + output_infill, gt_image_infill ).item() overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error( - gt_image_infill, output_infill + output_infill, gt_image_infill ).item() overall_stats_dict[current_subject_id]["mae"] = mean_absolute_error( - gt_image_infill, output_infill + output_infill, gt_image_infill ).item() # torchmetrics PSNR using "max" overall_stats_dict[current_subject_id]["psnr"] = peak_signal_noise_ratio( - gt_image_infill, output_infill + output_infill, gt_image_infill ).item() # same as above but with epsilon for robustness - overall_stats_dict[current_subject_id][ - "psnr_eps" - ] = peak_signal_noise_ratio( - gt_image_infill, output_infill, epsilon=sys.float_info.epsilon - ).item() + overall_stats_dict[current_subject_id]["psnr_eps"] = ( + peak_signal_noise_ratio( + output_infill, gt_image_infill, epsilon=sys.float_info.epsilon + ).item() + ) # only use fix data range to [0;1] if the data was normalized before if normalize: # torchmetrics PSNR but with fixed data range of 0 to 1 - overall_stats_dict[current_subject_id][ - "psnr_01" - ] = peak_signal_noise_ratio( - gt_image_infill, output_infill, data_range=(0, 1) - ).item() + overall_stats_dict[current_subject_id]["psnr_01"] = ( + peak_signal_noise_ratio( + output_infill, gt_image_infill, data_range=(0, 1) + ).item() + ) # same as above but with epsilon for robustness - overall_stats_dict[current_subject_id][ - "psnr_01_eps" - ] = peak_signal_noise_ratio( - gt_image_infill, - output_infill, - data_range=(0, 1), - epsilon=sys.float_info.epsilon, - ).item() + overall_stats_dict[current_subject_id]["psnr_01_eps"] = ( + peak_signal_noise_ratio( + output_infill, + gt_image_infill, + data_range=(0, 1), + epsilon=sys.float_info.epsilon, + ).item() + ) pprint(overall_stats_dict) if outputfile is not None: diff --git a/GANDLF/cli/main_run.py b/GANDLF/cli/main_run.py index ff9cb27fb..f9676f76b 100644 --- a/GANDLF/cli/main_run.py +++ b/GANDLF/cli/main_run.py @@ -1,4 +1,5 @@ import os, pickle +from typing import Optional from pathlib import Path from GANDLF.training_manager import TrainingManager, TrainingManager_split @@ -13,8 +14,15 @@ def main_run( - data_csv, config_file, model_dir, train_mode, device, resume, reset, output_dir=None -): + data_csv: str, + config_file: str, + model_dir: str, + train_mode: bool, + device: str, + resume: bool, + reset: bool, + output_dir: Optional[str] = None, +) -> None: """ Main function that runs the training and inference. @@ -26,7 +34,7 @@ def main_run( device (str): The device type. resume (bool): Whether the previous run will be resumed or not. reset (bool): Whether the previous run will be reset or not. - output_dir (str): The output directory for the inference session. + output_dir (str): The output directory for the inference session. Defaults to None. Returns: None diff --git a/GANDLF/cli/patch_extraction.py b/GANDLF/cli/patch_extraction.py index 85819ac0e..da0e73f53 100644 --- a/GANDLF/cli/patch_extraction.py +++ b/GANDLF/cli/patch_extraction.py @@ -1,4 +1,5 @@ import os, warnings +from typing import Optional, Union from functools import partial from pathlib import Path @@ -35,16 +36,17 @@ def parse_gandlf_csv(fpath): yield row["SubjectID"], row["Channel_0"], None -def patch_extraction(input_path, output_path, config=None): +def patch_extraction( + input_path: str, output_path: str, config: Optional[Union[str, dict]] = None +) -> None: """ - This function extracts patches from WSIs. + Extract patches from whole slide images. Args: - input_path (str): The input CSV. - config (Union[str, dict, none]): The input yaml config. - output_path (_type_): _description_ + input_path (str): The path to the input CSV file. + output_path (str): The path to the output directory. + config (Optional[Union[str, dict]], optional): The path to the configuration file. Defaults to None. """ - Image.MAX_IMAGE_PIXELS = None warnings.simplefilter("ignore") diff --git a/GANDLF/cli/preprocess_and_save.py b/GANDLF/cli/preprocess_and_save.py index 8a7803fd5..c9144b88b 100644 --- a/GANDLF/cli/preprocess_and_save.py +++ b/GANDLF/cli/preprocess_and_save.py @@ -1,6 +1,6 @@ import os, sys, pickle +from typing import Optional from pathlib import Path -import numpy as np import SimpleITK as sitk from GANDLF.utils import ( @@ -22,9 +22,9 @@ def preprocess_and_save( data_csv: str, config_file: str, output_dir: str, - label_pad_mode: str = "constant", - applyaugs: bool = False, - apply_zero_crop: bool = False, + label_pad_mode: Optional[str] = "constant", + applyaugs: Optional[bool] = False, + apply_zero_crop: Optional[bool] = False, ) -> None: """ This function performs preprocessing based on parameters provided and saves the output. @@ -33,12 +33,9 @@ def preprocess_and_save( data_csv (str): The CSV file of the training data. config_file (str): The YAML file of the training configuration. output_dir (str): The output directory. - label_pad_mode (str): The padding strategy for the label. Defaults to "constant". - applyaugs (bool): If data augmentation is to be applied before saving the image. Defaults to False. - apply_zero_crop (bool): If zero cropping is to be applied before saving the image. Defaults to False. - - Raises: - ValueError: Parameter check from previous + label_pad_mode (Optional[str], optional): The padding mode for the label. Defaults to "constant". + applyaugs (Optional[bool], optional): Whether to apply augmentations. Defaults to False. + apply_zero_crop (Optional[bool], optional): Whether to apply zero crop. Defaults to False. """ Path(output_dir).mkdir(parents=True, exist_ok=True) @@ -51,10 +48,9 @@ def preprocess_and_save( parameter_file = os.path.join(output_dir, "parameters.pkl") if os.path.exists(parameter_file): parameters_prev = pickle.load(open(parameter_file, "rb")) - if parameters != parameters_prev: - raise ValueError( - "The parameters are not the same as the ones stored in the previous run, please re-check." - ) + assert ( + parameters == parameters_prev + ), "The parameters are not the same as the ones stored in the previous run, please re-check." else: with open(parameter_file, "wb") as handle: pickle.dump(parameters, handle, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/GANDLF/cli/recover_config.py b/GANDLF/cli/recover_config.py index 88902f387..84d514eb4 100644 --- a/GANDLF/cli/recover_config.py +++ b/GANDLF/cli/recover_config.py @@ -3,7 +3,17 @@ import os -def recover_config(modelDir, outputFile): +def recover_config(modelDir: str, outputFile: str) -> bool: + """ + This function recovers the configuration file from a model directory. + + Args: + modelDir (str): The model directory with the configuration file and the model. + outputFile (str): The output file for the configuration. + + Returns: + bool: True if the configuration file was successfully recovered. + """ assert os.path.exists( modelDir ), "The model directory does not appear to exist. Please check parameters." diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py index 82cebdaa4..b0131c3a5 100644 --- a/GANDLF/compute/forward_pass.py +++ b/GANDLF/compute/forward_pass.py @@ -1,10 +1,12 @@ import os import pathlib +from typing import Optional, Tuple import numpy as np import pandas as pd import SimpleITK as sitk import torch +from torch.utils.data import DataLoader import torchio from GANDLF.compute.loss_and_metric import get_loss_and_metrics from GANDLF.compute.step import step @@ -23,29 +25,26 @@ def validate_network( - model, valid_dataloader, scheduler, params, epoch=0, mode="validation" -): + model: torch.nn.Module, + valid_dataloader: DataLoader, + scheduler: object, + params: dict, + epoch: Optional[int] = 0, + mode: Optional[str] = "validation", +) -> Tuple[float, dict]: """ - Function to validate a network for a single epoch - - Parameters - ---------- - model : if parameters["model"]["type"] == torch, this is a torch.model, otherwise this is OV exec_net - The model to process the input image with, it should support appropriate dimensions. - valid_dataloader : torch.DataLoader - The dataloader for the validation epoch - params : dict - The parameters passed by the user yaml - mode: str - The mode of validation, used to write outputs, if requested - - Returns - ------- - average_epoch_valid_loss : float - Validation loss for the current epoch - average_epoch_valid_metric : dict - Validation metrics for the current epoch - + Function to validate a network for a single epoch. + + Args: + model (torch.nn.Module): The model to process the input image with, it should support appropriate dimensions. if parameters["model"]["type"] == torch, this is a torch.model, otherwise this is OV exec_net. + valid_dataloader (DataLoader): The dataloader for the validation epoch. + scheduler (object): The scheduler to use for training. + params (dict): The parameters passed by the user yaml. + epoch (int, optional): The current epoch number. Defaults to 0. + mode (str, optional): The mode of operation. Defaults to "validation". + + Returns: + Tuple[float, dict]: The average validation loss and the average validation metrics. """ print("*" * 20) print("Starting " + mode + " : ") @@ -171,6 +170,10 @@ def validate_network( image = image.unsqueeze(0) image = image.float().to(params["device"]) ## special case for 2D + assert params["model"]["type"] in [ + "torch", + "openvino", + ], "Model type not supported. Please only use 'torch' or 'openvino'." if image.shape[-1] == 1: image = torch.squeeze(image, -1) if params["model"]["type"] == "torch": @@ -181,10 +184,6 @@ def validate_network( inputs={params["model"]["IO"][0][0]: image.cpu().numpy()} )[params["model"]["IO"][1][0]] ) - else: - raise Exception( - "Model type not supported. Please only use 'torch' or 'openvino'." - ) pred_output = pred_output.cpu() / params["q_samples_per_volume"] diff --git a/GANDLF/compute/generic.py b/GANDLF/compute/generic.py index d2210814e..7888a395f 100644 --- a/GANDLF/compute/generic.py +++ b/GANDLF/compute/generic.py @@ -1,4 +1,7 @@ +from typing import Optional, Tuple from pandas.util import hash_pandas_object +import torch +from torch.utils.data import DataLoader from GANDLF.models import get_model from GANDLF.schedulers import get_scheduler @@ -15,23 +18,30 @@ ) -def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu"): +def create_pytorch_objects( + parameters: dict, + train_csv: Optional[str] = None, + val_csv: Optional[str] = None, + device: Optional[str] = "cpu", +) -> Tuple[ + torch.nn.Module, + torch.optim.Optimizer, + DataLoader, + DataLoader, + torch.optim.lr_scheduler.LRScheduler, + dict, +]: """ - This function creates all the PyTorch objects needed for training. + This function creates the PyTorch objects needed for training and validation. Args: - parameters (dict): The parameters dictionary. - train_csv (str): The path to the training CSV file. - val_csv (str): The path to the validation CSV file. - device (str): The device to perform computations on. + parameters (dict): The parameters for the model and training. + train_csv (Optional[str], optional): The path to the training CSV file. Defaults to None. + val_csv (Optional[str], optional): The path to the validation CSV file. Defaults to None. + device (Optional[str], optional): The device to use for training. Defaults to "cpu". Returns: - model (torch.nn.Module): The model to use for training. - optimizer (Optimizer): The optimizer to use for training. - train_loader (torch.utils.data.DataLoader): The training data loader. - val_loader (torch.utils.data.DataLoader): The validation data loader. - scheduler (object): The scheduler to use for training. - parameters (dict): The updated parameters dictionary. + Tuple[ torch.nn.Module, torch.optim.Optimizer, DataLoader, DataLoader, torch.optim.lr_scheduler.LRScheduler, dict, ]: The model, optimizer, train loader, validation loader, scheduler, and parameters. """ # initialize train and val loaders train_loader, val_loader = None, None diff --git a/GANDLF/compute/inference_loop.py b/GANDLF/compute/inference_loop.py index bcab8f6ac..c7113d384 100644 --- a/GANDLF/compute/inference_loop.py +++ b/GANDLF/compute/inference_loop.py @@ -1,7 +1,9 @@ from .forward_pass import validate_network from .generic import create_pytorch_objects import os, sys +from typing import Optional from pathlib import Path +import pandas as pd # hides torchio citation request, see https://github.com/fepegar/torchio/issues/235 os.environ["TORCHIO_HIDE_CITATION_PROMPT"] = "1" @@ -20,24 +22,20 @@ latest_model_path_end, load_ov_model, print_model_summary, + applyCustomColorMap, ) from GANDLF.data.inference_dataloader_histopath import InferTumorSegDataset from GANDLF.data.preprocessing import get_transforms_for_preprocessing -def applyCustomColorMap(im_gray): - img_bgr = cv2.cvtColor(im_gray.astype(np.uint8), cv2.COLOR_BGR2RGB) - lut = np.zeros((256, 1, 3), dtype=np.uint8) - lut[:, 0, 0] = np.zeros((256)).tolist() - lut[:, 0, 1] = np.zeros((256)).tolist() - lut[:, 0, 2] = np.arange(0, 256, 1).tolist() - return cv2.LUT(img_bgr, lut) - - def inference_loop( - inferenceDataFromPickle, device, parameters, modelDir, outputDir=None -): + inferenceDataFromPickle: pd.DataFrame, + device: str, + parameters: dict, + modelDir: str, + outputDir: Optional[str] = None, +) -> None: """ The main training loop. diff --git a/GANDLF/compute/loss_and_metric.py b/GANDLF/compute/loss_and_metric.py index 7642dd7cc..e149c08db 100644 --- a/GANDLF/compute/loss_and_metric.py +++ b/GANDLF/compute/loss_and_metric.py @@ -1,16 +1,32 @@ import sys +from typing import Dict, Tuple from GANDLF.losses import global_losses_dict from GANDLF.metrics import global_metrics_dict +import torch import torch.nn.functional as nnf from GANDLF.utils import one_hot, reverse_one_hot, get_linear_interpolation_mode -def get_metric_output(metric_function, predicted, ground_truth, params): +def get_metric_output( + metric_function: object, + prediction: torch.Tensor, + target: torch.Tensor, + params: dict, +) -> float: """ - This function computes the output of a metric function. + This function computes the metric output for a given metric function, prediction and target. + + Args: + metric_function (object): The metric function to be used. + prediction (torch.Tensor): The input prediction label for the corresponding image label. + target (torch.Tensor): The input ground truth for the corresponding image label. + params (dict): The parameters passed by the user yaml. + + Returns: + float: The computed metric from the label and the prediction. """ - metric_output = metric_function(predicted, ground_truth, params).detach().cpu() + metric_output = metric_function(prediction, target, params).detach().cpu() if metric_output.dim() == 0: return metric_output.item() @@ -23,19 +39,20 @@ def get_metric_output(metric_function, predicted, ground_truth, params): return metric_output.item() -def get_loss_and_metrics(image, ground_truth, predicted, params): +def get_loss_and_metrics( + image: torch.Tensor, target: torch.Tensor, prediction: torch.Tensor, params: dict +) -> Tuple[torch.Tensor, Dict[str, float]]: """ - This function computes the loss and metrics for a given image, ground truth and predicted output. + This function computes the loss and metrics for a given image, ground truth and prediction output. Args: image (torch.Tensor): The input image stack according to requirements. - ground_truth (torch.Tensor): The input ground truth for the corresponding image label. - predicted (torch.Tensor): The input predicted label for the corresponding image label. + target (torch.Tensor): The input ground truth for the corresponding image label. + prediction (torch.Tensor): The input prediction label for the corresponding image label. params (dict): The parameters passed by the user yaml. Returns: - torch.Tensor: The computed loss from the label and the prediction. - dict: The computed metric from the label and the prediction. + Tuple[torch.Tensor, Dict[str,float]]: The computed loss and metrics from the label and the prediction. """ # this is currently only happening for mse_torch if isinstance(params["loss_function"], dict): @@ -53,14 +70,14 @@ def get_loss_and_metrics(image, ground_truth, predicted, params): loss = 0 # specialized loss function for sdnet - sdnet_check = (len(predicted) > 1) and (params["model"]["architecture"] == "sdnet") + sdnet_check = (len(prediction) > 1) and (params["model"]["architecture"] == "sdnet") if params["problem_type"] == "segmentation": - ground_truth = one_hot(ground_truth, params["model"]["class_list"]) + target = one_hot(target, params["model"]["class_list"]) deep_supervision_model = False if ( - (len(predicted) > 1) + (len(prediction) > 1) and not (sdnet_check) and ("deep" in params["model"]["architecture"]) ): @@ -69,17 +86,17 @@ def get_loss_and_metrics(image, ground_truth, predicted, params): # these weights are taken from previous publication (https://arxiv.org/pdf/2103.03759.pdf) loss_weights = [0.5, 0.25, 0.175, 0.075] - assert len(predicted) == len( + assert len(prediction) == len( loss_weights ), "Loss weights must be same length as number of outputs." ground_truth_resampled = [] - ground_truth_prev = ground_truth.detach() - for i, _ in enumerate(predicted): - if ground_truth_prev[0].shape != predicted[i][0].shape: + ground_truth_prev = target.detach() + for i, _ in enumerate(prediction): + if ground_truth_prev[0].shape != prediction[i][0].shape: # we get the expected shape of resampled ground truth expected_shape = reverse_one_hot( - predicted[i][0].detach(), params["model"]["class_list"] + prediction[i][0].detach(), params["model"]["class_list"] ).shape # linear interpolation is needed because we want "soft" images for resampled ground truth @@ -93,22 +110,22 @@ def get_loss_and_metrics(image, ground_truth, predicted, params): if sdnet_check: # this is specific for sdnet-style archs - loss_seg = loss_function(predicted[0], ground_truth.squeeze(-1), params) - loss_reco = global_losses_dict["l1"](predicted[1], image[:, :1, ...], None) - loss_kld = global_losses_dict["kld"](predicted[2], predicted[3]) - loss_cycle = global_losses_dict["mse"](predicted[2], predicted[4], None) + loss_seg = loss_function(prediction[0], target.squeeze(-1), params) + loss_reco = global_losses_dict["l1"](prediction[1], image[:, :1, ...], None) + loss_kld = global_losses_dict["kld"](prediction[2], prediction[3]) + loss_cycle = global_losses_dict["mse"](prediction[2], prediction[4], None) loss = 0.01 * loss_kld + loss_reco + 10 * loss_seg + loss_cycle else: if deep_supervision_model: # this is for models that have deep-supervision - for i, _ in enumerate(predicted): + for i, _ in enumerate(prediction): # loss is calculated based on resampled "soft" labels using a pre-defined weights array loss += ( - loss_function(predicted[i], ground_truth_resampled[i], params) + loss_function(prediction[i], ground_truth_resampled[i], params) * loss_weights[i] ) else: - loss = loss_function(predicted, ground_truth, params) + loss = loss_function(prediction, target, params) metric_output = {} # Metrics should be a list @@ -119,20 +136,20 @@ def get_loss_and_metrics(image, ground_truth, predicted, params): metric_function = global_metrics_dict[metric_lower] if sdnet_check: metric_output[metric] = get_metric_output( - metric_function, predicted[0], ground_truth.squeeze(-1), params + metric_function, prediction[0], target.squeeze(-1), params ) else: if deep_supervision_model: - for i, _ in enumerate(predicted): + for i, _ in enumerate(prediction): metric_output[metric] += get_metric_output( metric_function, - predicted[i], + prediction[i], ground_truth_resampled[i], params, ) else: metric_output[metric] = get_metric_output( - metric_function, predicted, ground_truth, params + metric_function, prediction, target, params ) return loss, metric_output diff --git a/GANDLF/compute/step.py b/GANDLF/compute/step.py index fc506390d..c36258c47 100644 --- a/GANDLF/compute/step.py +++ b/GANDLF/compute/step.py @@ -1,32 +1,28 @@ +from typing import Optional, Tuple import torch import psutil from .loss_and_metric import get_loss_and_metrics -def step(model, image, label, params, train=True): +def step( + model: torch.nn.Module, + image: torch.Tensor, + label: torch.Tensor, + params: dict, + train: Optional[bool] = True, +) -> Tuple[float, dict, torch.Tensor, torch.Tensor]: """ - Function that steps the model for a single batch + This function performs a single step of training or validation. - Parameters - ---------- - model : torch.model - The model to process the input image with, it should support appropriate dimensions. - image : torch.Tensor - The input image stack according to requirements - label : torch.Tensor - The input label for the corresponding image label - params : dict - The parameters passed by the user yaml - - Returns - ------- - loss : torch.Tensor - The computed loss from the label and the output - metric_output : torch.Tensor - The computed metric from the label and the output - output: torch.Tensor - The final output of the model + Args: + model (torch.nn.Module): The model to process the input image with, it should support appropriate dimensions. + image (torch.Tensor): The input image stack according to requirements. + label (torch.Tensor): The input label for the corresponding image tensor. + params (dict): The parameters dictionary. + train (Optional[bool], optional): Whether the step is for training or validation. Defaults to True. + Returns: + Tuple[float, dict, torch.Tensor, torch.Tensor]: The loss, metrics, output, and attention map. """ if params["verbose"]: if torch.cuda.is_available(): diff --git a/GANDLF/compute/training_loop.py b/GANDLF/compute/training_loop.py index 078cf86f7..d129757e3 100644 --- a/GANDLF/compute/training_loop.py +++ b/GANDLF/compute/training_loop.py @@ -1,5 +1,8 @@ import os, time, psutil +from typing import Tuple +import pandas as pd import torch +from torch.utils.data import DataLoader from tqdm import tqdm import numpy as np import torchio @@ -33,28 +36,23 @@ os.environ["TORCHIO_HIDE_CITATION_PROMPT"] = "1" -def train_network(model, train_dataloader, optimizer, params): +def train_network( + model: torch.nn.Module, + train_dataloader: DataLoader, + optimizer: torch.optim.Optimizer, + params: dict, +) -> Tuple[float, dict]: """ - Function to train a network for a single epoch - - Parameters - ---------- - model : torch.model - The model to process the input image with, it should support appropriate dimensions. - train_dataloader : torch.DataLoader - The dataloader for the training epoch - optimizer : torch.optim - Optimizer for optimizing network - params : dict - the parameters passed by the user yaml - - Returns - ------- - average_epoch_train_loss : float - Train loss for the current epoch - average_epoch_train_metric : dict - Train metrics for the current epoch + This function performs the training of the network. + Args: + model (torch.nn.Module): The model to process the input image with, it should support appropriate dimensions. + train_dataloader (DataLoader): The dataloader for the training epoch. + optimizer (torch.optim.Optimizer): Optimizer for optimizing network. + params (dict): The parameters dictionary. + + Returns: + Tuple[float, dict]: The average epoch training loss and metrics. """ print("*" * 20) print("Starting Training : ") @@ -217,24 +215,24 @@ def train_network(model, train_dataloader, optimizer, params): def training_loop( - training_data, - validation_data, - device, - params, - output_dir, - testing_data=None, - epochs=None, -): + training_data: pd.DataFrame, + validation_data: pd.DataFrame, + device: str, + params: dict, + output_dir: str, + testing_data: bool = None, + epochs: bool = None, +) -> None: """ The main training loop. Args: - training_data (pandas.DataFrame): The data to use for training. - validation_data (pandas.DataFrame): The data to use for validation. + training_data (pd.DataFrame): The data to use for training. + validation_data (pd.DataFrame): The data to use for validation. device (str): The device to perform computations on. params (dict): The parameters dictionary. output_dir (str): The output directory. - testing_data (pandas.DataFrame): The data to use for testing. + testing_data (pd.DataFrame): The data to use for testing. epochs (int): The number of epochs to train; if None, take from params. """ # Some autodetermined factors @@ -560,7 +558,7 @@ def training_loop( if __name__ == "__main__": - import argparse, pickle, pandas + import argparse, pickle torch.multiprocessing.freeze_support() # parse the cli arguments here @@ -572,7 +570,11 @@ def training_loop( "-val_loader_pickle", type=str, help="Validation loader pickle", required=True ) parser.add_argument( - "-testing_loader_pickle", type=str, help="Testing loader pickle", required=True + "-testing_loader_pickle", + type=str, + help="Testing loader pickle", + required=False, + default=None, ) parser.add_argument( "-parameter_pickle", type=str, help="Parameters pickle", required=True @@ -584,13 +586,10 @@ def training_loop( # # write parameters to pickle - this should not change for the different folds, so keeping is independent parameters = pickle.load(open(args.parameter_pickle, "rb")) - trainingDataFromPickle = pandas.read_pickle(args.train_loader_pickle) - validationDataFromPickle = pandas.read_pickle(args.val_loader_pickle) + trainingDataFromPickle = pd.read_pickle(args.train_loader_pickle) + validationDataFromPickle = pd.read_pickle(args.val_loader_pickle) testingData_str = args.testing_loader_pickle - if testingData_str == "None": - testingDataFromPickle = None - else: - testingDataFromPickle = pandas.read_pickle(testingData_str) + testingDataFromPickle = pd.read_pickle(testingData_str) if testingData_str else None training_loop( training_data=trainingDataFromPickle, diff --git a/GANDLF/config_manager.py b/GANDLF/config_manager.py index ce402a173..80e147064 100644 --- a/GANDLF/config_manager.py +++ b/GANDLF/config_manager.py @@ -1,3 +1,4 @@ +from typing import Optional, Union import sys, yaml, ast, pkg_resources import numpy as np from copy import deepcopy @@ -43,18 +44,23 @@ } -def initialize_parameter(params, parameter_to_initialize, value=None, evaluate=True): +def initialize_parameter( + params: dict, + parameter_to_initialize: str, + value: Optional[Union[str, list, int, dict]] = None, + evaluate: Optional[bool] = True, +) -> dict: """ - Initializes the specified parameter with supplied value + This function will initialize the parameter in the parameters dict to the value if it is absent. Args: params (dict): The parameter dictionary. parameter_to_initialize (str): The parameter to initialize. - value ((Union[str, list, int]), optional): The value to initialize. Defaults to None. - evaluate (bool, optional): String evaluate. Defaults to True. + value (Optional[Union[str, list, int, dict]], optional): The value to initialize. Defaults to None. + evaluate (Optional[bool], optional): Whether to evaluate the value. Defaults to True. Returns: - [type]: [description] + dict: The parameter dictionary. """ if parameter_to_initialize in params: if evaluate: @@ -72,17 +78,19 @@ def initialize_parameter(params, parameter_to_initialize, value=None, evaluate=T return params -def initialize_key(parameters, key, value=None): +def initialize_key( + parameters: dict, key: str, value: Optional[Union[str, float, list, dict]] = None +) -> dict: """ - This function will initialize the key in the parameters dict to 'None' if it is absent or length is zero. + This function initializes a key in the parameters dictionary to a value if it is absent. Args: parameters (dict): The parameter dictionary. - key (str): The parameter to initialize. - value (n.a.): The value to initialize. + key (str): The key to initialize. + value (Optional[Union[str, float, list, dict]], optional): The value to initialize. Defaults to None. Returns: - dict: The final parameter dictionary. + dict: The parameter dictionary. """ if parameters is None: parameters = {} @@ -98,7 +106,9 @@ def initialize_key(parameters, key, value=None): return parameters -def _parseConfig(config_file_path, version_check_flag=True): +def _parseConfig( + config_file_path: Union[str, dict], version_check_flag: bool = True +) -> None: """ This function parses the configuration file and returns a dictionary of parameters. @@ -370,9 +380,11 @@ def _parseConfig(config_file_path, version_check_flag=True): default_range = ( [-0.1, 0.1] if augmentation_type == "hed_transform" - else [-0.03, 0.03] - if augmentation_type == "hed_transform_light" - else [-0.95, 0.95] + else ( + [-0.03, 0.03] + if augmentation_type == "hed_transform_light" + else [-0.95, 0.95] + ) ) for key in ranges: @@ -716,7 +728,9 @@ def _parseConfig(config_file_path, version_check_flag=True): return params -def ConfigManager(config_file_path, version_check_flag=True) -> None: +def ConfigManager( + config_file_path: Union[str, dict], version_check_flag: bool = True +) -> None: """ This function parses the configuration file and returns a dictionary of parameters. diff --git a/GANDLF/data/ImagesFromDataFrame.py b/GANDLF/data/ImagesFromDataFrame.py index 70790f1fb..39bc9f8cb 100644 --- a/GANDLF/data/ImagesFromDataFrame.py +++ b/GANDLF/data/ImagesFromDataFrame.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Optional, Union import os from pathlib import Path import numpy as np @@ -37,8 +37,8 @@ def ImagesFromDataFrame( dataframe: pandas.DataFrame, parameters: dict, train: bool, - apply_zero_crop: bool = False, - loader_type: str = "", + apply_zero_crop: Optional[bool] = False, + loader_type: Optional[str] = None, ) -> Union[torchio.SubjectsDataset, torchio.Queue]: """ Reads the pandas dataframe and gives the dataloader to use for training/validation/testing. @@ -47,15 +47,13 @@ def ImagesFromDataFrame( dataframe (pandas.DataFrame): The main input dataframe which is calculated after splitting the data CSV. parameters (dict): The parameters dictionary. train (bool): If the dataloader is for training or not. - apply_zero_crop (bool, optional): If zero crop is to be applied. Defaults to False. - loader_type (str, optional): The type of loader to use for printing. Defaults to "". - - Raises: - ValueError: If the subject cannot be loaded. + apply_zero_crop (Optional[bool], optional): Whether to apply zero crop or not. Defaults to False. + loader_type (Optional[str], optional): The type of loader. Defaults to None. Returns: Union[torchio.SubjectsDataset, torchio.Queue]: The dataloader queue for validation/testing (where patching and data augmentation is not required) or the subjects dataset for training. """ + loader_type = loader_type if loader_type is not None else "" # store in previous variable names patch_size = parameters["patch_size"] headers = parameters["headers"] @@ -256,11 +254,9 @@ def _save_resized_images( # Appending this subject to the list of subjects subjects_list.append(subject) - if subjects_with_error: - raise ValueError( - "The following subjects could not be loaded, please recheck or remove and retry:", - subjects_with_error, - ) + assert ( + subjects_with_error is not None + ), f"The following subjects could not be loaded, please recheck or remove and retry: {subjects_with_error}" transformations_list = [] diff --git a/GANDLF/data/augmentation/blur_enhanced.py b/GANDLF/data/augmentation/blur_enhanced.py index b1b7344e9..67bf5a8c7 100644 --- a/GANDLF/data/augmentation/blur_enhanced.py +++ b/GANDLF/data/augmentation/blur_enhanced.py @@ -48,7 +48,7 @@ def get_params(self, std_ranges: TypeSextetFloat) -> TypeTripletFloat: std = self.sample_uniform_sextet(std_ranges) return std - def calculate_std_ranges(self, image: torch.Tensor) -> tuple: + def calculate_std_ranges(self, image: torch.Tensor) -> Tuple[float, float]: std_ranges = self.std_original if self.std_original is None: # calculate the default std range based on 1.5% of the input image std - https://github.com/mlcommons/GaNDLF/issues/518 diff --git a/GANDLF/data/augmentation/hed_augs.py b/GANDLF/data/augmentation/hed_augs.py index c917e6bb3..46bf3c4bb 100644 --- a/GANDLF/data/augmentation/hed_augs.py +++ b/GANDLF/data/augmentation/hed_augs.py @@ -88,7 +88,7 @@ def __init__( cutoff_range (Union[tuple, None]): Patches with mean value outside the cutoff interval will not be augmented. Values from the [0.0, 1.0] range. The RGB channel values are from the same range. Returns: - ColorAugmenterBase: _description_ + ColorAugmenterBase: The color augmenter object. """ # Initialize base class. diff --git a/GANDLF/data/augmentation/noise_enhanced.py b/GANDLF/data/augmentation/noise_enhanced.py index 41fb6ed74..5cea5a74e 100644 --- a/GANDLF/data/augmentation/noise_enhanced.py +++ b/GANDLF/data/augmentation/noise_enhanced.py @@ -64,7 +64,7 @@ def get_params( seed = self._get_random_seed() return mean, std, seed - def calculate_std_ranges(self, image: torch.Tensor) -> tuple: + def calculate_std_ranges(self, image: torch.Tensor) -> Tuple[float, float]: std_ranges = self.std_original if self.std_original is None: # calculate the default std range based on 1.5% of the input image std - https://github.com/mlcommons/GaNDLF/issues/518 diff --git a/GANDLF/data/augmentation/rgb_augs.py b/GANDLF/data/augmentation/rgb_augs.py index b2b90ac45..3c261f188 100644 --- a/GANDLF/data/augmentation/rgb_augs.py +++ b/GANDLF/data/augmentation/rgb_augs.py @@ -1,5 +1,5 @@ from torchvision.transforms import ColorJitter -from typing import Tuple, Union +from typing import Optional, Tuple, Union from torchio.transforms.augmentation import RandomTransform from torchio.transforms import IntensityTransform from torchio import Subject @@ -38,10 +38,10 @@ class RandomColorJitter(RandomTransform, IntensityTransform): def __init__( self, - brightness: Union[float, Tuple[float, float]] = 0.1, - contrast: Union[float, Tuple[float, float]] = 0, - saturation: Union[float, Tuple[float, float]] = 0, - hue: Union[float, Tuple[float, float]] = 0.2, + brightness: Optional[Union[float, Tuple[float, float]]] = 0.1, + contrast: Optional[Union[float, Tuple[float, float]]] = 0, + saturation: Optional[Union[float, Tuple[float, float]]] = 0, + hue: Optional[Union[float, Tuple[float, float]]] = 0.2, **kwargs ): super().__init__(**kwargs) diff --git a/GANDLF/data/augmentation/rotations.py b/GANDLF/data/augmentation/rotations.py index 1f97503a3..c6a1fae0d 100644 --- a/GANDLF/data/augmentation/rotations.py +++ b/GANDLF/data/augmentation/rotations.py @@ -1,23 +1,20 @@ from functools import partial +from typing import List import torch from torchio.transforms import Lambda -def axis_check(axis): +def axis_check(axis: List[int]) -> List[int]: """ - Check the input axis. + This function checks the axis for rotation. Args: - axis (list): Input axis. - - Raises: - ValueError: If axis is not in [1, 2, 3]. + axis (List[int]): The axes of rotation. Returns: - list: Output affected axes. + List[int]: The affected axes. """ - if isinstance(axis, int): if axis == 0: axis = [1] @@ -27,8 +24,14 @@ def axis_check(axis): for count, _ in enumerate(axis): axis[count] += 1 for sub_ax in axis: - if sub_ax not in [1, 2, 3]: - raise ValueError("Axes must be in [1, 2, 3], but was provided as: ", sub_ax) + assert isinstance( + sub_ax, int + ), f"Axis must be an integer, but was provided as: {sub_ax}" + assert sub_ax in [ + 1, + 2, + 3, + ], f"Axes must be in [1, 2, 3], but was provided as: {sub_ax}" relevant_axes = set([1, 2, 3]) if relevant_axes == set(axis): @@ -41,16 +44,13 @@ def axis_check(axis): return affected_axes -def tensor_rotate_90(input_image, axis): +def tensor_rotate_90(input_image: torch.Tensor, axis: List[int]) -> torch.Tensor: """ This function rotates an image by 90 degrees around the specified axis. Args: input_image (torch.Tensor): The input tensor. - axis (list): The axes of rotation. - - Raises: - ValueError: If axis is not in [1, 2, 3]. + axis (List[int]): The axes of rotation. Returns: torch.Tensor: The rotated tensor. @@ -64,16 +64,13 @@ def tensor_rotate_90(input_image, axis): ) -def tensor_rotate_180(input_image, axis): +def tensor_rotate_180(input_image: torch.Tensor, axis: List[int]) -> torch.Tensor: """ This function rotates an image by 180 degrees around the specified axis. Args: input_image (torch.Tensor): The input tensor. - axis (list): The axes of rotation. - - Raises: - ValueError: If axis is not in [1, 2, 3]. + axis (List[int]): The axes of rotation. Returns: torch.Tensor: The rotated tensor. @@ -84,14 +81,32 @@ def tensor_rotate_180(input_image, axis): return input_image.flip(affected_axes[0]).flip(affected_axes[1]) -def rotate_90(parameters): +def rotate_90(parameters: dict) -> Lambda: + """ + This function rotates an image by 90 degrees around the specified axis. + + Args: + parameters (dict): The parameters for the rotation. + + Returns: + Lambda: The rotation function. + """ return Lambda( function=partial(tensor_rotate_90, axis=parameters["axis"]), p=parameters["probability"], ) -def rotate_180(parameters): +def rotate_180(parameters: dict) -> Lambda: + """ + This function rotates an image by 180 degrees around the specified axis. + + Args: + parameters (dict): The parameters for the rotation. + + Returns: + Lambda: The rotation function. + """ return Lambda( function=partial(tensor_rotate_180, axis=parameters["axis"]), p=parameters["probability"], diff --git a/GANDLF/data/inference_dataloader_histopath.py b/GANDLF/data/inference_dataloader_histopath.py index 9c9d7ef55..0abc2f651 100644 --- a/GANDLF/data/inference_dataloader_histopath.py +++ b/GANDLF/data/inference_dataloader_histopath.py @@ -1,13 +1,5 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Fri Mar 8 20:03:35 2019 - -@author: siddhesh -""" - import os - +from typing import Optional import numpy as np import tiffslide from GANDLF.data.patch_miner.opm.utils import get_patch_size_in_microns, tissue_mask @@ -15,16 +7,15 @@ from torch.utils.data.dataset import Dataset -def get_tissue_mask(image): +def get_tissue_mask(image: np.ndarray) -> np.ndarray: """ This function is used to generate tissue masks; works for patches as well Args: - img_rgb (numpy.array): Input image. - rgb_min (int, optional): The minimum threshold. Defaults to 50. + img_rgb (np.ndarray): Input image. Returns: - numpy.array: The tissue mask. + np.ndarray: The tissue mask. """ try: resized_image = resize(image, (512, 512), anti_aliasing=True) @@ -46,7 +37,7 @@ def __init__( stride_size, selected_level, mask_level, - transform=None, + transform: Optional[object] = None, ): self.transform = transform self._wsi_path = wsi_path diff --git a/GANDLF/data/patch_miner/opm/utils.py b/GANDLF/data/patch_miner/opm/utils.py index 53bf3605b..1bee9b1f1 100644 --- a/GANDLF/data/patch_miner/opm/utils.py +++ b/GANDLF/data/patch_miner/opm/utils.py @@ -1,4 +1,5 @@ import sys, os +from typing import List, Optional, Tuple from pathlib import Path import numpy as np import skimage.io @@ -38,7 +39,16 @@ LAB_L_THRESHOLD = 0.80 -def print_sorted_dict(dictionary): +def print_sorted_dict(dictionary: dict) -> str: + """ + Print a dictionary with sorted keys. + + Args: + dictionary (dict): The input dictionary. + + Returns: + str: The sorted dictionary. + """ sorted_keys = sorted(list(dictionary.keys())) output_str = "{" for index, key in enumerate(sorted_keys): @@ -50,7 +60,22 @@ def print_sorted_dict(dictionary): return output_str -def convert_to_tiff(filename, output_dir, img_type="converted"): +def convert_to_tiff( + filename: str, + output_dir: str, + updated_file_name_identifier: Optional[str] = "converted", +) -> str: + """ + Convert an image to tiff. + + Args: + filename (str): The input filename. + output_dir (str): The output directory. + updated_file_name_identifier (str, optional): The identifier to use for the updated file name. Defaults to "converted". + + Returns: + str: The path to the converted image. + """ base, ext = os.path.splitext(filename) # for png or jpg images, write image back to tiff if ext in [".png", ".jpg", ".jpeg"]: @@ -58,7 +83,7 @@ def convert_to_tiff(filename, output_dir, img_type="converted"): Path(converted_img_path).mkdir(parents=True, exist_ok=True) temp_file = os.path.join( converted_img_path, - os.path.basename(base) + "_" + img_type + ".tiff", + os.path.basename(base) + "_" + updated_file_name_identifier + ".tiff", ) temp_img = skimage.io.imread(filename) skimage.io.imsave(temp_file, temp_img) @@ -67,31 +92,34 @@ def convert_to_tiff(filename, output_dir, img_type="converted"): return filename -def pass_method(*args): - """ - Method which takes any number of arguments and returns and empty string. Like 'pass' reserved word, but as a func. - @param args: Any number of arguments. - @return: An empty string. - """ +def pass_method(*args: object) -> str: return "" -def get_nonzero_percent(image): +def get_nonzero_percent(image: np.ndarray) -> float: """ - Return what percentage of image is non-zero. Useful for finding percentage of labels for binary classification. - @param image: label map patch. - @return: fraction of image that is not zero. + Get the percentage of non-zero pixels in an image. + + Args: + image (np.ndarray): The input image. + + Returns: + float: The percentage of non-zero pixels. """ np_img = np.asarray(image) non_zero = np.count_nonzero(np_img) return non_zero / (np_img.shape[0] * np_img.shape[1]) -def get_patch_class_proportions(image): +def get_patch_class_proportions(image: np.ndarray) -> dict: """ - Return what percentage of image is non-zero. Useful for finding percentage of labels for binary classification. - @param image: label map patch. - @return: fraction of image that is not zero. + Get the class proportions of a patch. + + Args: + image (np.ndarray): The input image. + + Returns: + dict: The class proportions """ np_img = np.asarray(image) unique, counts = np.unique(image, return_counts=True) @@ -100,12 +128,16 @@ def get_patch_class_proportions(image): return print_sorted_dict(prop_dict) -def map_values(image, dictionary): +def map_values(image: np.ndarray, dictionary: dict) -> np.ndarray: """ - Modify image by swapping dictionary keys to dictionary values. - @param image: Numpy ndarray of an image (usually label map patch). - @param dictionary: dict(int => int). Keys in image are swapped to corresponding values. - @return: + Map values in an image to a new set of values. + + Args: + image (np.ndarray): The input image. + dictionary (dict): The dictionary to use for mapping. + + Returns: + np.ndarray: The mapped image. """ template = image.copy() # Copy image so all values not in dict are unmodified for key, value in dictionary.items(): @@ -122,7 +154,21 @@ def map_values(image, dictionary): # plt.show() -def hue_range_mask(image, min_hue, max_hue, sat_min=0.05): +def hue_range_mask( + image: np.ndarray, min_hue: float, max_hue: float, sat_min: Optional[float] = 0.05 +) -> np.ndarray: + """ + Mask based on hue range. + + Args: + image (np.ndarray): RGB numpy image + min_hue (float): Minimum hue value + max_hue (float): Maximum hue value + sat_min (Optional[float], optional): Minimum saturation value. Defaults to 0.05. + + Returns: + np.ndarray: image mask, True pixels are within the hue range. + """ hsv_image = rgb2hsv(image) h_channel = gaussian(hsv_image[:, :, HSV_HUE_CHANNEL]) above_min = h_channel > min_hue @@ -133,10 +179,15 @@ def hue_range_mask(image, min_hue, max_hue, sat_min=0.05): return np.logical_and(np.logical_and(above_min, below_max), above_sat) -def tissue_mask(image): +def tissue_mask(image: np.ndarray) -> np.ndarray: """ - Quick and dirty hue range mask for OPM. Works well on H&E. - TODO: Improve this + Mask based on low saturation and value (gray-black colors) + + Args: + image (np.ndarray): RGB numpy image + + Returns: + np.ndarray: image mask, True pixels are gray-black. """ hue_mask = hue_range_mask(image, 0.8, 0.99) final_mask = remove_small_holes(hue_mask) @@ -200,7 +251,18 @@ def tissue_mask(image): # return mask_copy -def patch_size_check(img, patch_height, patch_width): +def patch_size_check(img: np.ndarray, patch_height: int, patch_width: int) -> bool: + """ + This function checks if the patch size is valid. + + Args: + img (np.ndarray): Input image. + patch_height (int): The height of the patch. + patch_width (int): The width of the patch. + + Returns: + bool: Whether or not the patch size is valid. + """ img = np.asarray(img) return_val = False @@ -210,7 +272,7 @@ def patch_size_check(img, patch_height, patch_width): return return_val -def alpha_rgb_2d_channel_check(img): +def alpha_rgb_2d_channel_check(img: np.ndarray) -> bool: """ This function checks if an image has a valid alpha channel. @@ -263,22 +325,24 @@ def alpha_rgb_2d_channel_check(img): def patch_artifact_check( - img, - intensity_thresh=250, - intensity_thresh_saturation=5, - intensity_thresh_b=128, - patch_size=(256, 256), -): + img: np.ndarray, + intensity_thresh: int = 250, + intensity_thresh_saturation: int = 5, + intensity_thresh_b: int = 128, + patch_size: Optional[List[int]] = [256, 256], +) -> bool: """ - This function is used to curate patches from the input image. It is used to remove patches that are mostly background. + This function is used to curate patches from the input image. It is used to remove patches that have artifacts. + Args: img (np.ndarray): Input Patch Array to check the artifact/background. - intensity_thresh (int, optional): Threshold to check whiteness in the patch. Defaults to 225. - intensity_thresh_saturation (int, optional): Threshold to check saturation in the patch. Defaults to 50. - intensity_thresh_b (int, optional) : Threshold to check blackness in the patch - patch_size (int, optional): Tiling Size of the WSI/patch size. Defaults to 256. patch_size=config["patch_size"] + intensity_thresh (int, optional): Threshold to check whiteness in the patch. Defaults to 250. + intensity_thresh_saturation (int, optional): Threshold to check saturation in the patch. Defaults to 5. + intensity_thresh_b (int, optional): Threshold to check blackness in the patch. Defaults to 128. + patch_size (Optional[List[int]], optional): Tiling Size of the WSI/patch size. Defaults to [256, 256]. + Returns: - bool: Whether the patch is valid (True) or not (False) + bool: Whether the patch is valid or not. """ # patch_size = config["patch_size"] patch_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) @@ -310,11 +374,15 @@ def patch_artifact_check( return True -def parse_config(config_file): +def parse_config(config_file: str) -> dict: """ - Parse config file and return a dictionary of config values. - :param config_file: path to config file - :return: dictionary of config values + Function that parses the config file. + + Args: + config_file (str): The path to the config file. + + Returns: + dict: The parsed config file. """ config = yaml.safe_load(open(config_file, "r")) @@ -331,7 +399,7 @@ def parse_config(config_file): return config -def is_mask_too_big(mask): +def is_mask_too_big(mask: np.ndarray) -> bool: """ Function that returns a boolean value indicating whether the mask is too big to make processing slow. @@ -348,12 +416,16 @@ def is_mask_too_big(mask): return False -def generate_initial_mask(slide_path, scale): +def generate_initial_mask(slide_path: str, scale: int) -> Tuple[np.ndarray, tuple]: """ - Helper method to generate random coordinates within a slide - :param slide_path: Path to slide (str) - :param num_patches: Number of patches you want to generate - :return: list of n (x,y) coordinates + Function that generates the initial mask for the slide. + + Args: + slide_path (str): The path to the slide. + scale (int): The scale to use for the mask. + + Returns: + Tuple[np.ndarray, tuple]: The valid mask and the real scale. """ # Open slide and get properties slide = tiffslide.open_slide(slide_path) @@ -378,25 +450,27 @@ def generate_initial_mask(slide_path, scale): return valid_mask, real_scale -def get_patch_size_in_microns(input_slide_path, patch_size_from_config, verbose=False): +def get_patch_size_in_microns( + input_slide_path: str, patch_size_from_config: str, verbose: Optional[bool] = False +) -> List[int]: """ - This function takes a slide path and a patch size in microns and returns the patch size in pixels. + Function that returns the patch size in pixels. Args: - input_slide_path (str): The input WSI path. - patch_size_from_config (str): The patch size in microns. - verbose (bool): Whether to provide verbose prints. - - Raises: - ValueError: If the patch size is not a valid number in microns. + input_slide_path (str): The path to the slide. + patch_size_from_config (str): The patch size from the config file. + verbose (Optional[bool], optional): Whether to print verbose output. Defaults to False. Returns: - list: The patch size in pixels. + List[int]: The patch size after getting converted to pixels. """ - return_patch_size = [0, 0] patch_size = None + assert isinstance( + patch_size_from_config, (str, list, tuple) + ), "Patch size must be a list or string." + if isinstance(patch_size_from_config, str): # first remove all spaces and square brackets patch_size_from_config = patch_size_from_config.replace(" ", "") @@ -410,16 +484,13 @@ def get_patch_size_in_microns(input_slide_path, patch_size_from_config, verbose= patch_size = patch_size_from_config.split("X") if len(patch_size) == 1: patch_size = patch_size_from_config.split("*") - if len(patch_size) == 1: - raise ValueError( - "Could not parse patch size from config.yml, use either ',', 'x', 'X', or '*' as separator between x and y dimensions." - ) + assert ( + len(patch_size) == 2 + ), "Could not parse patch size from config.yml, use either ',', 'x', 'X', or '*' as separator between x and y dimensions." elif isinstance(patch_size_from_config, list) or isinstance( patch_size_from_config, tuple ): patch_size = patch_size_from_config - else: - raise ValueError("Patch size must be a list or string.") magnification_prev = -1 for i, _ in enumerate(patch_size): diff --git a/GANDLF/data/post_process/morphology.py b/GANDLF/data/post_process/morphology.py index a81bf2c1d..ef124f42b 100644 --- a/GANDLF/data/post_process/morphology.py +++ b/GANDLF/data/post_process/morphology.py @@ -1,3 +1,4 @@ +from typing import Optional import torch import torch.nn.functional as F from skimage.measure import label @@ -6,25 +7,29 @@ from GANDLF.utils.generic import get_array_from_image_or_tensor -def torch_morphological(input_image, kernel_size=1, mode="dilation"): +def torch_morphological( + input_image, kernel_size: Optional[int] = 1, mode: Optional[str] = "dilation" +) -> torch.Tensor: """ - This function enables morphological operations using torch. Adapted from https://github.com/DIVA-DIA/Generating-Synthetic-Handwritten-Historical-Documents/blob/e6a798dc2b374f338804222747c56cb44869af5b/HTR_ctc/utils/auxilary_functions.py#L10. + This function performs morphological operations on the input image. Adapted from https://github.com/DIVA-DIA/Generating-Synthetic-Handwritten-Historical-Documents/blob/e6a798dc2b374f338804222747c56cb44869af5b/HTR_ctc/utils/auxilary_functions.py#L10. Args: - input_image (torch.Tensor): The input image. - kernel_size (list): The size of the window to take a max over. - mode (str): The type of morphological operation to perform. + input_image (_type_): The input image. + kernel_size (Optional[int], optional): The size of the window to take a max over. + mode (Optional[str], optional): The mode of the morphological operation. Returns: torch.Tensor: The output image after morphological operations. """ - + assert mode in ["dilation", "erosion", "closing", "opening"], "Invalid mode." + assert len(input_image.shape) in [ + 4, + 5, + ], "Invalid input shape for morphological operations." if len(input_image.shape) == 4: max_pool = F.max_pool2d elif len(input_image.shape) == 5: max_pool = F.max_pool3d - else: - raise ValueError("Input image has invalid shape for morphological operations.") if mode == "dilation": output_image = max_pool( @@ -52,13 +57,15 @@ def torch_morphological(input_image, kernel_size=1, mode="dilation"): return output_image -def fill_holes(input_image, params=None): +def fill_holes( + input_image: torch.Tensor, params: Optional[dict] = None +) -> torch.Tensor: """ This function fills holes in masks. Args: input_image (torch.Tensor): The input image. - params (dict): The parameters dict; unused. + params (Optional[dict], optional): The parameters dict. Defaults to None. Returns: torch.Tensor: The output image after morphological operations. @@ -71,13 +78,13 @@ def fill_holes(input_image, params=None): return torch.from_numpy(output_array) -def cca(input_image, params=None): +def cca(input_image: torch.Tensor, params: Optional[dict] = None) -> torch.Tensor: """ This function performs connected component analysis on the input image. Args: input_image (torch.Tensor): The input image. - params (dict): The parameters dict; + params (Optional[dict], optional): The parameters dict. Defaults to None. Returns: torch.Tensor: The output image after morphological operations. diff --git a/GANDLF/data/post_process/tensor.py b/GANDLF/data/post_process/tensor.py index 0caa91083..937d9e9f6 100644 --- a/GANDLF/data/post_process/tensor.py +++ b/GANDLF/data/post_process/tensor.py @@ -1,16 +1,18 @@ import numpy as np +import torch from GANDLF.utils.generic import get_array_from_image_or_tensor -def get_mapped_label(input_tensor, params): +def get_mapped_label(input_tensor: torch.Tensor, params:dict) -> np.ndarray: """ - This function maps the input label to the output label. + This function maps the input tensor to the output tensor based on the mapping provided in the params. + Args: - input_tensor (Union[torch.Tensor, sitk.Image]): The input label. - params (dict): The parameters dict. + input_tensor (torch.Tensor): The input tensor. + params (dict): The parameters dictionary. Returns: - np.ndarray: The output image after morphological operations. + np.ndarray: The output tensor. """ input_image_array = get_array_from_image_or_tensor(input_tensor) if "data_postprocessing" not in params: diff --git a/GANDLF/data/preprocessing/crop_zero_planes.py b/GANDLF/data/preprocessing/crop_zero_planes.py index ce22ea8b7..287901261 100644 --- a/GANDLF/data/preprocessing/crop_zero_planes.py +++ b/GANDLF/data/preprocessing/crop_zero_planes.py @@ -1,5 +1,5 @@ +from typing import List, Tuple import numpy as np - import torch import nibabel as nib @@ -7,26 +7,23 @@ # adapted from https://codereview.stackexchange.com/questions/132914/crop-black-border-of-image-using-numpy/132933#132933 -def crop_image_outside_zeros(array, patch_size): +def crop_image_outside_zeros( + array: np.ndarray, patch_size: List[int] +) -> Tuple[np.ndarray, np.ndarray]: """ This function rotates an image by 90 degrees around the specified axis. Args: array (numpy.array): The input array. - patch_size (list): The patch size. - - Raises: - ValueError: Array needs to be 4D. + patch_size (List[int]): The patch size. Returns: - numpy.array: The new corner indeces. - numpy.array: The new cropped array. + Tuple[np.ndarray, np.ndarray]: The new corner indices and the new array. """ dimensions = len(array.shape) - if dimensions != 4: - raise ValueError( - "Array expected to be 4D but got {} dimensions.".format(dimensions) - ) + assert dimensions == 4, "Array expected to be 4D but got {} dimensions.".format( + dimensions + ) # collapse to single channel and get the mask of non-zero voxels mask = array.sum(axis=0) > 0 diff --git a/GANDLF/data/preprocessing/normalize_rgb.py b/GANDLF/data/preprocessing/normalize_rgb.py index 99d84b7a7..163fa9780 100644 --- a/GANDLF/data/preprocessing/normalize_rgb.py +++ b/GANDLF/data/preprocessing/normalize_rgb.py @@ -1,5 +1,5 @@ import torch -from typing import List +from typing import List, Optional from torchio.transforms.intensity_transform import IntensityTransform from torchio.data.subject import Subject @@ -15,7 +15,12 @@ class NormalizeRGB(IntensityTransform): """ - def __init__(self, mean: list = None, std: list = None, **kwargs): + def __init__( + self, + mean: Optional[List[float]] = None, + std: Optional[List[float]] = None, + **kwargs, + ): super().__init__(**kwargs) self.mean, self.std = mean, std self.args_names = "mean", "std" diff --git a/GANDLF/data/preprocessing/resample_minimum.py b/GANDLF/data/preprocessing/resample_minimum.py index 0b2ee29b0..7f36d1063 100644 --- a/GANDLF/data/preprocessing/resample_minimum.py +++ b/GANDLF/data/preprocessing/resample_minimum.py @@ -1,3 +1,4 @@ +from typing import Optional import numpy as np import SimpleITK as sitk @@ -10,10 +11,10 @@ class Resample_Minimum(Resample): This performs resampling of an image to the minimum spacing specified by a single number. Otherwise, it will perform standard resampling. Args: - Resample (_type_): _description_ + Resample (SpatialTransform): The parent class for resampling. """ - def __init__(self, target: float = 1, **kwargs): + def __init__(self, target: Optional[float] = 1, **kwargs): super().__init__(**kwargs) @staticmethod diff --git a/GANDLF/data/preprocessing/rgb_conversion.py b/GANDLF/data/preprocessing/rgb_conversion.py index fd92875f3..0371a5f99 100644 --- a/GANDLF/data/preprocessing/rgb_conversion.py +++ b/GANDLF/data/preprocessing/rgb_conversion.py @@ -1,3 +1,4 @@ +from typing import Optional import torch import PIL.Image @@ -59,9 +60,27 @@ def apply_transform(self, subject: Subject) -> Subject: return subject -def rgba2rgb_transform(parameters=None): +def rgba2rgb_transform(parameters: Optional[dict] = None) -> RGBA2RGB: + """ + This function returns the transform to convert RGBA to RGB. + + Args: + parameters (dict, optional): The parameters for the transform. Defaults to None. + + Returns: + RGBA2RGB: The transform to convert RGBA to RGB. + """ return RGBA2RGB() -def rgb2rgba_transform(parameters=None): +def rgb2rgba_transform(parameters: Optional[dict] = None) -> RGB2RGBA: + """ + This function returns the transform to convert RGB to RGBA. + + Args: + parameters (dict, optional): The parameters for the transform. Defaults to None. + + Returns: + RGB2RGBA: The transform to convert RGB to RGBA. + """ return RGB2RGBA() diff --git a/GANDLF/data/preprocessing/template_matching/histogram_matching.py b/GANDLF/data/preprocessing/template_matching/histogram_matching.py index cb352b48a..3922c2f5e 100644 --- a/GANDLF/data/preprocessing/template_matching/histogram_matching.py +++ b/GANDLF/data/preprocessing/template_matching/histogram_matching.py @@ -45,12 +45,15 @@ def apply_normalize(self, image: ScalarImage) -> None: image.from_sitk(normalized_img) -def histogram_matching(parameters): +def histogram_matching(parameters: dict) -> HistogramMatching: """ - This function is a wrapper for histogram matching. + This function performs histogram matching. Args: - parameters (dict): Dictionary of parameters. + parameters (dict): The parameters for the histogram matching. + + Returns: + HistogramMatching: The histogram matching object. """ num_hist_level = parameters.get("num_hist_level", 1024) num_match_points = parameters.get("num_match_points", 16) diff --git a/GANDLF/data/preprocessing/template_matching/utils.py b/GANDLF/data/preprocessing/template_matching/utils.py index 4b804fb97..0d44d6a90 100644 --- a/GANDLF/data/preprocessing/template_matching/utils.py +++ b/GANDLF/data/preprocessing/template_matching/utils.py @@ -1,31 +1,24 @@ """ adapted from https://github.com/TissueImageAnalytics/tiatoolbox/blob/master/tiatoolbox/tools/stainextract.py """ + +from typing import Optional import numpy as np from skimage import exposure import cv2 -def contrast_enhancer(img, low_p=2, high_p=98): +def contrast_enhancer( + img: np.ndarray, low_p: Optional[int] = 2, high_p: Optional[int] = 98 +) -> np.ndarray: """ - Enhancing contrast of the input image using intensity adjustment. - This method uses both image low and high percentiles. + Enhance contrast of an image using percentile rescaling. Args: - img (:class:`numpy.ndarray`): input image used to obtain tissue mask. - Image should be uint8. - low_p (scalar): low percentile of image values to be saturated to 0. - high_p (scalar): high percentile of image values to be saturated to 255. - high_p should always be greater than low_p. + img (np.ndarray): The input image. + low_p (Optional[int], optional): The low percentile. Defaults to 2. + high_p (Optional[int], optional): The high percentile. Defaults to 98. Returns: - img (:class:`numpy.ndarray`): Image (uint8) with contrast enhanced. - - Raises: - AssertionError: Internal errors due to invalid img type. - - Examples: - >>> from tiatoolbox import utils - >>> img = utils.misc.contrast_enhancer(img, low_p=2, high_p=98) - + np.ndarray: The contrast enhanced image. """ # check if image is not uint8 assert img.dtype == np.uint8, "Image should be uint8" @@ -40,20 +33,16 @@ def contrast_enhancer(img, low_p=2, high_p=98): return np.uint8(img_out) -def get_luminosity_tissue_mask(img, threshold): - """Get tissue mask based on the luminosity of the input image. +def get_luminosity_tissue_mask(img: np.ndarray, threshold: float) -> np.ndarray: + """ + Compute tissue mask based on luminosity thresholding. Args: - img (:class:`numpy.ndarray`): input image used to obtain tissue mask. - threshold (float): luminosity threshold used to determine tissue area. + img (np.ndarray): The input image. + threshold (float): The threshold for luminosity. Returns: - tissue_mask (:class:`numpy.ndarray`): binary tissue mask. - - Examples: - >>> from tiatoolbox import utils - >>> tissue_mask = utils.misc.get_luminosity_tissue_mask(img, threshold=0.8) - + np.ndarray: The tissue mask. """ img = img.astype("uint8") # ensure input image is uint8 img = contrast_enhancer(img, low_p=2, high_p=98) # Contrast enhancement @@ -67,59 +56,44 @@ def get_luminosity_tissue_mask(img, threshold): return tissue_mask -def rgb2od(img): - """Convert from RGB to optical density (OD_RGB) space. - RGB = 255 * exp(-1*OD_RGB). +def rgb2od(img: np.ndarray) -> np.ndarray: + """ + Convert from RGB to optical density (OD_RGB). Args: - img (:class:`numpy.ndarray` of type :class:`numpy.uint8`): Image RGB + img (np.ndarray): The input image. Returns: - :class:`numpy.ndarray`: Optical density RGB image. - - Examples: - >>> from tiatoolbox.utils import transforms, misc - >>> rgb_img = misc.imread('path/to/image') - >>> od_img = transforms.rgb2od(rgb_img) - + np.ndarray: The optical density RGB image. """ mask = img == 0 img[mask] = 1 return np.maximum(-1 * np.log(img / 255), 1e-6) -def od2rgb(od): - """Convert from optical density (OD_RGB) to RGB. - RGB = 255 * exp(-1*OD_RGB) +def od2rgb(od: np.ndarray) -> np.ndarray: + """ + Convert from optical density to RGB. Args: - od (:class:`numpy.ndarray`): Optical density RGB image + od (np.ndarray): The optical density image. Returns: - numpy.ndarray: Image RGB - - Examples: - >>> from tiatoolbox.utils import transforms, misc - >>> rgb_img = misc.imread('path/to/image') - >>> od_img = transforms.rgb2od(rgb_img) - >>> rgb_img = transforms.od2rgb(od_img) - + np.ndarray: The RGB image. """ od = np.maximum(od, 1e-6) return (255 * np.exp(-1 * od)).astype(np.uint8) -def dl_output_for_h_and_e(dictionary): - """Return correct value for H and E from dictionary learning output. +def dl_output_for_h_and_e(dictionary: np.ndarray) -> np.ndarray: + """ + Rearrange dictionary for H&E in correct order with H as first output. Args: - dictionary (:class:`numpy.ndarray`): - :class:`sklearn.decomposition.DictionaryLearning` output + dictionary (np.ndarray): The input dictionary. Returns: - :class:`numpy.ndarray`: - With correct values for H and E. - + np.ndarray: The dictionary in the correct order. """ return_dictionary = dictionary if dictionary[0, 0] < dictionary[1, 0]: @@ -128,19 +102,16 @@ def dl_output_for_h_and_e(dictionary): return return_dictionary -def h_and_e_in_right_order(v1, v2): - """Rearrange input vectors for H&E in correct order with H as first output. +def h_and_e_in_right_order(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: + """ + Rearrange vectors for H&E in correct order with H as first output. Args: - v1 (:class:`numpy.ndarray`): - Input vector for stain extraction. - v2 (:class:`numpy.ndarray`): - Input vector for stain extraction. + v1 (np.ndarray): The first vector for stain extraction. + v2 (np.ndarray): The second vector for stain extraction. Returns: - :class:`numpy.ndarray`: - Input vectors in the correct order. - + np.ndarray: The vectors in the correct order. """ return_arr = np.array([v2, v1]) if v1[0] > v2[0]: @@ -149,17 +120,15 @@ def h_and_e_in_right_order(v1, v2): return return_arr -def vectors_in_correct_direction(e_vectors): - """Points the eigen vectors in the right direction. +def vectors_in_correct_direction(e_vectors: np.ndarray) -> np.ndarray: + """ + Ensure that the vectors are in the correct direction. Args: - e_vectors (:class:`numpy.ndarray`): - Eigen vectors. + e_vectors (np.ndarray): The input vectors. Returns: - :class:`numpy.ndarray`: - Pointing in the correct direction. - + np.ndarray: The vectors in the correct direction. """ if e_vectors[0, 0] < 0: e_vectors[:, 0] *= -1 diff --git a/GANDLF/data/preprocessing/threshold_and_clip.py b/GANDLF/data/preprocessing/threshold_and_clip.py index f96b32c71..848a93008 100644 --- a/GANDLF/data/preprocessing/threshold_and_clip.py +++ b/GANDLF/data/preprocessing/threshold_and_clip.py @@ -1,3 +1,4 @@ +from typing import Optional import torch from torchio.data.subject import Subject @@ -35,7 +36,7 @@ class Threshold(IntensityTransform): """ - def __init__(self, out_min: float = None, out_max: float = None, **kwargs): + def __init__(self, out_min: Optional[float] = None, out_max: Optional[float] = None, **kwargs): super().__init__(**kwargs) self.out_min, self.out_max = out_min, out_max self.args_names = "out_min", "out_max" @@ -56,9 +57,27 @@ def threshold(self, tensor: torch.Tensor) -> torch.Tensor: # the "_transform" functions return lambdas that can be used to wrap into a Compose class -def threshold_transform(parameters): +def threshold_transform(parameters:dict) -> Threshold: + """ + This function returns a lambda function that can be used to wrap into a Compose class. + + Args: + parameters (dict): The parameters dictionary. + + Returns: + Threshold: The transform to threshold the image. + """ return Threshold(out_min=parameters["min"], out_max=parameters["max"]) -def clip_transform(parameters): +def clip_transform(parameters:dict) -> Clamp: + """ + This function returns a lambda function that can be used to wrap into a Compose class. + + Args: + parameters (dict): The parameters dictionary. + + Returns: + Clamp: The transform to clip the image. + """ return Clamp(out_min=parameters["min"], out_max=parameters["max"]) diff --git a/GANDLF/grad_clipping/adaptive_gradient_clipping.py b/GANDLF/grad_clipping/adaptive_gradient_clipping.py index 59f6403de..c2d308cd6 100644 --- a/GANDLF/grad_clipping/adaptive_gradient_clipping.py +++ b/GANDLF/grad_clipping/adaptive_gradient_clipping.py @@ -2,20 +2,20 @@ """ Implementation of Adaptive gradient clipping """ - +from typing import List, Optional import torch -def unitwise_norm(x, norm_type=2.0): +def unitwise_norm(x: torch.Tensor, norm_type: Optional[float] = 2.0) -> torch.Tensor: """ - Computes the norm of a tensor x, where the norm is applied across all dimensions except the first one. + Compute norms of each weight tensor in a model, and return the global norm. Args: - x (torch.Tensor): Input tensor. - norm_type (float): The type of norm to compute (default: 2.0). + x (torch.Tensor): The input tensor. + norm_type (Optional[float], optional): The type of norm to compute. Defaults to 2.0. Returns: - torch.Tensor: The norm of the tensor. + torch.Tensor: The global norm. """ if x.ndim <= 1: return x.norm(norm_type) @@ -28,15 +28,20 @@ def unitwise_norm(x, norm_type=2.0): return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) -def adaptive_gradient_clip_(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): +def adaptive_gradient_clip_( + parameters: List[torch.Tensor], + clip_factor: Optional[float] = 0.01, + eps: Optional[float] = 1e-3, + norm_type: Optional[float] = 2.0, +) -> None: """ Performs adaptive gradient clipping on the parameters of a PyTorch model. Args: - parameters (list of torch.Tensor): The parameters to be clipped. - clip_factor (float): The factor by which to clip the gradients (default: 0.01). - eps (float): A small value added to the norm to avoid division by zero (default: 1e-3). - norm_type (float): The type of norm to compute (default: 2.0). + parameters (List[torch.Tensor]): The parameters to be clipped. + clip_factor (Optional[float], optional): The clipping factor. Defaults to 0.01. + eps (Optional[float], optional): The epsilon value. Defaults to 1e-3. + norm_type (Optional[float], optional): The type of norm to compute. Defaults to 2.0. Adaptive Gradient Clipping Original implementation of Adaptive Gradient Clipping derived from diff --git a/GANDLF/grad_clipping/clip_gradients.py b/GANDLF/grad_clipping/clip_gradients.py index ea1608a79..d10f58899 100644 --- a/GANDLF/grad_clipping/clip_gradients.py +++ b/GANDLF/grad_clipping/clip_gradients.py @@ -2,28 +2,30 @@ """ Implementation of functions to clip gradients """ - +from typing import Optional import torch from GANDLF.grad_clipping.adaptive_gradient_clipping import adaptive_gradient_clip_ def dispatch_clip_grad_( - parameters, value: float, mode: str = "norm", norm_type: float = 2.0 -): + parameters: torch.Tensor, + value: float, + mode: Optional[str] = "norm", + norm_type: Optional[float] = 2.0, +) -> None: """ Dispatches the gradient clipping method to the corresponding function based on the mode. Args: - parameters (Iterable): The model parameters to be clipped. + parameters (torch.Tensor): The model parameters to be clipped. value (float): The clipping value/factor/norm, mode dependent. - mode (str): The clipping mode, one of 'norm', 'value', 'agc' (default: 'norm'). - norm_type (float): The p-norm to use for computing the norm of the gradients (default: 2.0). + mode (Optional[str], optional): The mode of clipping. Defaults to "norm". + norm_type (Optional[float], optional): The type of norm to compute. Defaults to 2.0. """ + assert mode in ["norm", "value", "agc"], f"Unknown clip mode ({mode})." if mode == "norm": torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) elif mode == "value": torch.nn.utils.clip_grad_value_(parameters, value) elif mode == "agc": adaptive_gradient_clip_(parameters, value, norm_type=norm_type) - else: - assert False, f"Unknown clip mode ({mode})." diff --git a/GANDLF/grad_clipping/grad_scaler.py b/GANDLF/grad_clipping/grad_scaler.py index a8a0a7d55..e6acac98f 100644 --- a/GANDLF/grad_clipping/grad_scaler.py +++ b/GANDLF/grad_clipping/grad_scaler.py @@ -1,3 +1,4 @@ +from typing import Optional import torch from GANDLF.grad_clipping.clip_gradients import dispatch_clip_grad_ @@ -9,12 +10,12 @@ def __init__(self): def __call__( self, - loss, - optimizer, - clip_grad=None, - clip_mode="norm", - parameters=None, - create_graph=False, + loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + clip_grad: Optional[float] = None, + clip_mode: Optional[str] = "norm", + parameters: Optional[torch.Tensor] = None, + create_graph: Optional[bool] = False, ): """ Scales the loss and performs backward pass through the computation graph. @@ -22,10 +23,10 @@ def __call__( Args: loss (torch.Tensor): The loss tensor to scale and backpropagate. optimizer (torch.optim.Optimizer): The optimizer to step after backpropagation. - clip_grad (float): The clipping value/factor/norm, mode dependent (default: None). - clip_mode (str): The clipping mode, one of 'norm', 'value', 'agc' (default: 'norm'). - parameters (Iterable): The model parameters to clip (default: None). - create_graph (bool): Whether to create a new graph for backpropagation (default: False). + clip_grad (Optional[float], optional): The clipping value/factor/norm, mode dependent. Defaults to None. + clip_mode (Optional[str], optional): The mode of clipping. Defaults to "norm". + parameters (Optional[torch.Tensor], optional): The model parameters to be clipped. Defaults to None. + create_graph (Optional[bool], optional): Whether to create the graph. Defaults to False. """ self._scaler.scale(loss).backward(create_graph=create_graph) if clip_grad is not None: @@ -54,21 +55,23 @@ def load_state_dict(self, state_dict): self._scaler.load_state_dict(state_dict) -def model_parameters_exclude_head(model, clip_mode=None): +def model_parameters_exclude_head( + model: torch.nn.Module, clip_mode: Optional[str] = None +): """ Returns the parameters of a PyTorch model excluding the last two layers (the head). Args: model (torch.nn.Module): The PyTorch model to get the parameters from. - clip_mode (str): The clipping mode, one of 'norm', 'value', 'agc' (default: None). + clip_mode (Optional[str], optional): The mode of clipping. Defaults to None. Returns: Iterable: The model parameters excluding the last two layers if clip_mode is 'agc', otherwise all parameters. """ exclude_head = False - if clip_mode is not None: - if clip_mode == "agc": - exclude_head = True + clip_mode = str(clip_mode).lower() if clip_mode is not None else None + if clip_mode == "agc": + exclude_head = True if exclude_head: return [p for p in model.parameters()][:-2] else: diff --git a/GANDLF/inference_manager.py b/GANDLF/inference_manager.py index 24db87f35..2d418340e 100644 --- a/GANDLF/inference_manager.py +++ b/GANDLF/inference_manager.py @@ -1,4 +1,5 @@ import os +from typing import Optional from pathlib import Path import pandas as pd import torch @@ -8,19 +9,22 @@ from GANDLF.utils import get_unique_timestamp -def InferenceManager(dataframe, modelDir, parameters, device, outputDir=None): +def InferenceManager( + dataframe: pd.DataFrame, + modelDir: str, + parameters: dict, + device: str, + outputDir: Optional[str] = None, +) -> None: """ - This function takes in a dataframe, with some other parameters and performs the inference on the data in the dataframe. + This function is used to perform inference on a model using a dataframe. Args: - dataframe (pandas.DataFrame): The dataframe containing the data to be used for inference. - modelDir (str): The path to the directory containing the model to be used for inference. - outputDir (str): The path to the directory where the output of the inference will be stored. - parameters (dict): The dictionary containing the parameters for the inference. - device (str): The device type. - - Returns: - None + dataframe (pd.DataFrame): The dataframe containing the data to be used for inference. + modelDir (str): The path to the model directory. + parameters (dict): The parameters to be used for inference. + device (str): The device to be used for inference. + outputDir (Optional[str], optional): The output directory for the inference results. Defaults to None. """ # get the indeces for kfold splitting inferenceData_full = dataframe diff --git a/GANDLF/logger.py b/GANDLF/logger.py index f5c2dbf39..bb3168583 100755 --- a/GANDLF/logger.py +++ b/GANDLF/logger.py @@ -7,24 +7,18 @@ """ import os +from typing import Dict import torch class Logger: - def __init__(self, logger_csv_filename, metrics): + def __init__(self, logger_csv_filename: str, metrics: Dict[str, float]) -> None: """ + Logger class to log the training and validation metrics to a csv file. - Parameters - ---------- - logger_csv_filename : String - Path to a filename where the csv has to be stored - metric : list - Should be a list of the metrics - - Returns - ------- - None. - + Args: + logger_csv_filename (str): Path to a filename where the csv has to be stored. + metrics (Dict[str, float]): The metrics to be logged. """ self.filename = logger_csv_filename self.metrics = metrics @@ -44,22 +38,16 @@ def write_header(self, mode="train"): # print("Found a pre-existing file for logging, now appending logs to that file!") self.csv.close() - def write(self, epoch_number, loss, epoch_metrics): + def write( + self, epoch_number: int, loss: float, epoch_metrics: Dict[str, float] + ) -> None: """ + Write the epoch number, loss and metrics to the csv file. - Parameters - ---------- - epoch_number : TYPE - DESCRIPTION. - loss : TYPE - DESCRIPTION. - metrics : TYPE - DESCRIPTION. - - Returns - ------- - None. - + Args: + epoch_number (int): The epoch number. + loss (float): The loss value. + epoch_metrics (Dict[str, float]): The metrics to be logged. """ self.csv = open(self.filename, "a") row = "" diff --git a/GANDLF/losses/hybrid.py b/GANDLF/losses/hybrid.py index b6d53a08f..ddf62fa01 100644 --- a/GANDLF/losses/hybrid.py +++ b/GANDLF/losses/hybrid.py @@ -4,54 +4,56 @@ from .regression import CCE_Generic, CE, CE_Logits -def DCCE(predicted_mask, ground_truth, params) -> torch.Tensor: +def DCCE(prediction: torch.Tensor, target: torch.Tensor, params: dict) -> torch.Tensor: """ Calculates the Dice-Cross-Entropy loss. Args: - predicted_mask (torch.Tensor): The predicted mask. - ground_truth (torch.Tensor): The ground truth mask. + prediction (torch.Tensor): The predicted mask. + target (torch.Tensor): The ground truth mask. params (dict): The parameters. Returns: torch.Tensor: The calculated loss. """ - dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic( - predicted_mask, ground_truth, params, CE + dcce_loss = MCD_loss(prediction, target, params) + CCE_Generic( + prediction, target, params, CE ) return dcce_loss -def DCCE_Logits(predicted_mask, ground_truth, params): +def DCCE_Logits( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: """ Calculates the Dice-Cross-Entropy loss using logits. Args: - predicted_mask (torch.Tensor): The predicted mask. - ground_truth (torch.Tensor): The ground truth mask. + prediction (torch.Tensor): The predicted mask. + target (torch.Tensor): The ground truth mask. params (dict): The parameters. Returns: torch.Tensor: The calculated loss. """ - dcce_loss = MCD_loss(predicted_mask, ground_truth, params) + CCE_Generic( - predicted_mask, ground_truth, params, CE_Logits + dcce_loss = MCD_loss(prediction, target, params) + CCE_Generic( + prediction, target, params, CE_Logits ) return dcce_loss -def DC_Focal(predicted_mask, ground_truth, params): +def DC_Focal( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: """ Calculates the Dice-Focal loss. Args: - predicted_mask (torch.Tensor): The predicted mask. - ground_truth (torch.Tensor): The ground truth mask. + prediction (torch.Tensor): The predicted mask. + target (torch.Tensor): The ground truth mask. params (dict): The parameters. Returns: torch.Tensor: The calculated loss. """ - return MCD_loss(predicted_mask, ground_truth, params) + FocalLoss( - predicted_mask, ground_truth, params - ) + return MCD_loss(prediction, target, params) + FocalLoss(prediction, target, params) diff --git a/GANDLF/losses/regression.py b/GANDLF/losses/regression.py index 3e34f49ec..bd7911895 100644 --- a/GANDLF/losses/regression.py +++ b/GANDLF/losses/regression.py @@ -1,10 +1,13 @@ +from typing import Optional import torch import torch.nn.functional as F from torch.nn import MSELoss, CrossEntropyLoss, L1Loss from GANDLF.utils import one_hot -def CEL(prediction, target, params): +def CEL( + prediction: torch.Tensor, target: torch.Tensor, params: dict = None +) -> torch.Tensor: """ Cross entropy loss with optional class weights. @@ -34,7 +37,7 @@ def CEL(prediction, target, params): return cel(prediction, target) -def CE_Logits(prediction, target): +def CE_Logits(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Binary cross entropy loss with logits. @@ -53,7 +56,7 @@ def CE_Logits(prediction, target): return loss_val -def CE(prediction, target): +def CE(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Binary cross entropy loss. @@ -74,7 +77,12 @@ def CE(prediction, target): return loss_val -def CCE_Generic(prediction, target, params, CCE_Type): +def CCE_Generic( + prediction: torch.Tensor, + target: torch.Tensor, + params: dict, + CCE_Type: Optional[torch.nn.modules.loss._Loss] = CrossEntropyLoss, +) -> torch.Tensor: """ Generic function to calculate CCE loss @@ -82,7 +90,7 @@ def CCE_Generic(prediction, target, params, CCE_Type): prediction (torch.tensor): The predicted output value for each pixel. dimension: [batch, class, x, y, z]. target (torch.tensor): The ground truth target for each pixel. dimension: [batch, class, x, y, z] factorial_class_list. params (dict): The parameter dictionary. - CCE_Type (torch.nn): The CE loss function type. + CCE_Type (torch.nn.modules.loss._Loss, optional): The type of cross entropy loss to use. Defaults to CrossEntropyLoss. Returns: torch.tensor: The final loss value after taking multiple classes into consideration @@ -104,23 +112,23 @@ def CCE_Generic(prediction, target, params, CCE_Type): return acc_ce_loss -def L1(prediction, target, reduction="mean", scaling_factor=1): - """ - Calculate the mean absolute error between the output variable from the network and the target - Parameters - ---------- - prediction : torch.Tensor - The prediction generated by the network - target : torch.Tensor - The target for the corresponding Tensor for which the output was generated - reduction : str, optional - The type of reduction to apply to the output. Can be "none", "mean", or "sum". Default is "mean". - scaling_factor : int, optional - The scaling factor to multiply the target with. Default is 1. - Returns - ------- - loss : torch.Tensor - The computed Mean Absolute Error (L1) loss for the output and target +def L1( + prediction: torch.Tensor, + target: torch.Tensor, + reduction: Optional[str] = "mean", + scaling_factor: Optional[float] = 1, +) -> torch.Tensor: + """ + Calculate the mean absolute error between the output variable from the network and the target. + + Args: + prediction (torch.Tensor): The prediction generated usually by the network. + target (torch.Tensor): The target for the corresponding Tensor for which the output was generated. + reduction (Optional[str], optional): The reduction method for the loss. Defaults to 'mean'. + scaling_factor (Optional[float], optional): The scaling factor to multiply the target with. Defaults to 1. + + Returns: + torch.Tensor: The mean absolute error loss. """ scaling_factor = torch.as_tensor( scaling_factor, dtype=target.dtype, device=target.device @@ -130,14 +138,16 @@ def L1(prediction, target, reduction="mean", scaling_factor=1): return loss -def L1_loss(prediction, target, params): +def L1_loss( + prediction: torch.Tensor, target: torch.Tensor, params: Optional[dict] = None +) -> torch.Tensor: """ - Computes the L1 loss between the predictionut tensor and the target tensor. + Computes the L1 loss between the prediction tensor and the target tensor. - Parameters: - prediction (torch.Tensor): The predictionut tensor. + Args: + prediction (torch.Tensor): The prediction tensor. target (torch.Tensor): The target tensor. - params (dict, optional): A dictionary of hyperparameters. Defaults to None. + params (Optional[dict], optional): The dictionary of parameters. Defaults to None. Returns: loss (torch.Tensor): The computed L1 loss. @@ -178,23 +188,23 @@ def L1_loss(prediction, target, params): return acc_mse_loss -def MSE(prediction, target, reduction="mean", scaling_factor=1): - """ - Calculate the mean square error between the output variable from the network and the target - Parameters - ---------- - prediction : torch.Tensor - The prediction generated usually by the network - target : torch.Tensor - The target for the corresponding Tensor for which the output was generated - reduction : string, optional - DESCRIPTION. The default is 'mean'. - scaling_factor : float, optional - The scaling factor to multiply the target with - Returns - ------- - loss : torch.Tensor - Computed Mean Squared Error loss for the output and target +def MSE( + prediction: torch.Tensor, + target: torch.Tensor, + reduction: Optional[str] = "mean", + scaling_factor: Optional[float] = 1, +) -> torch.Tensor: + """ + Compute the mean squared error loss for the prediction and target + + Args: + prediction (torch.Tensor): The prediction generated usually by the network. + target (torch.Tensor): The target for the corresponding Tensor for which the output was generated. + reduction (Optional[str], optional): The reduction method for the loss. Defaults to 'mean'. + scaling_factor (Optional[float], optional): The scaling factor to multiply the target with. Defaults to 1. + + Returns: + torch.Tensor: The computed mean squared error loss. """ scaling_factor = torch.as_tensor(scaling_factor, dtype=torch.float32) target = target.float() * scaling_factor @@ -202,35 +212,26 @@ def MSE(prediction, target, reduction="mean", scaling_factor=1): return loss -def MSE_loss(prediction, target, params=None): +def MSE_loss( + prediction: torch.Tensor, target: torch.Tensor, params: Optional[dict] = None +) -> torch.Tensor: """ - Compute the mean squared error loss for the predictionut and target + Compute the mean squared error loss for the prediction and target. - Parameters - ---------- - prediction : torch.Tensor - The predictionut tensor - target : torch.Tensor - The target tensor - params : dict, optional - A dictionary of parameters. Default: None. - If params is not None and contains the key "loss_function", the value of - "loss_function" is expected to be a dictionary with a key "mse", which - can contain the key "reduction" and/or "scaling_factor". If "reduction" is - not specified, the default is 'mean'. If "scaling_factor" is not specified, - the default is 1. + Args: + prediction (torch.Tensor): The prediction generated usually by the network. + target (torch.Tensor): The target for the corresponding Tensor for which the output was generated. + params (Optional[dict], optional): The dictionary of parameters. Defaults to None. - Returns - ------- - acc_mse_loss : torch.Tensor - Computed mean squared error loss for the predictionut and target + Returns: + torch.Tensor: The computed mean squared error loss. """ reduction = "mean" scaling_factor = 1 if params is not None and "loss_function" in params: mse_params = params["loss_function"].get("mse", {}) - reduction = mse_params.get("reduction", "mean") - scaling_factor = mse_params.get("scaling_factor", 1) + reduction = mse_params.get("reduction", reduction) + scaling_factor = mse_params.get("scaling_factor", scaling_factor) if prediction.shape[0] == 1: acc_mse_loss = MSE( diff --git a/GANDLF/losses/segmentation.py b/GANDLF/losses/segmentation.py index ba8c58903..32e43bc25 100644 --- a/GANDLF/losses/segmentation.py +++ b/GANDLF/losses/segmentation.py @@ -1,4 +1,5 @@ import sys +from typing import List, Optional import torch @@ -61,9 +62,9 @@ def generic_loss_calculator( target: torch.Tensor, num_class: int, loss_criteria, - weights: list = None, - ignore_class: int = None, - loss_type: int = 0, + weights: Optional[List[float]] = None, + ignore_class: Optional[int] = None, + loss_type: Optional[int] = 0, ) -> torch.Tensor: """ This function computes the mean class dice score between two tensors @@ -73,9 +74,9 @@ def generic_loss_calculator( target (torch.Tensor): Required target label to match the predicted with num_class (int): Number of classes (including the background class) loss_criteria (function): Loss function to use - weights (list, optional): Dice weights for each class (excluding the background class), defaults to None - ignore_class (int, optional): Class to ignore, defaults to None - loss_type (int, optional): Type of loss to compute, defaults to 0. The options are: + weights (Optional[List[float]], optional): Dice weights for each class (excluding the background class), defaults to None + ignore_class (Optional[int], optional): Class to ignore, defaults to None + loss_type (Optional[int], optional): Type of loss to compute, defaults to 0. The options are: 0: no loss, normal dice calculation 1: dice loss, (1-dice) 2: log dice, -log(dice) @@ -214,16 +215,19 @@ def MCC_log_loss( def tversky_loss( - predicted: torch.Tensor, target: torch.Tensor, alpha: float = 0.5, beta: float = 0.5 + predicted: torch.Tensor, + target: torch.Tensor, + alpha: Optional[float] = 0.5, + beta: Optional[float] = 0.5, ) -> torch.Tensor: """ This function calculates the Tversky loss between two tensors. Args: - predicted (torch.Tensor): Predicted generally by the network - target (torch.Tensor): Required target label to match the predicted with - alpha (float, optional): Weight of false positives. Defaults to 0.5. - beta (float, optional): Weight of false negatives. Defaults to 0.5. + predicted (torch.Tensor): Predicted generally by the network. + target (torch.Tensor): Required target label to match the predicted with. + alpha (Optional[float], optional): The alpha value for Tversky loss. Defaults to 0.5. + beta (Optional[float], optional): The beta value for Tversky loss. Defaults to 0.5. Returns: torch.Tensor: Computed Tversky Loss @@ -249,14 +253,14 @@ def tversky_loss( def MCT_loss( - predicted: torch.Tensor, target: torch.Tensor, params: dict = None + predicted: torch.Tensor, target: torch.Tensor, params: Optional[dict] = None ) -> torch.Tensor: """ This function calculates the Multi-Class Tversky loss between two tensors. Args: - predicted (torch.Tensor): Predicted generally by the network - target (torch.Tensor): Required target label to match the predicted with + predicted (torch.Tensor): Predicted generally by the network. + target (torch.Tensor): Required target label to match the predicted with. params (dict, optional): Additional parameters for computing loss function, including weights for each class Returns: @@ -278,14 +282,14 @@ def MCT_loss( return acc_tv_loss -def KullbackLeiblerDivergence(mu, logvar, params=None): +def KullbackLeiblerDivergence(mu, logvar, params: Optional[dict] = None): """ Calculates the Kullback-Leibler divergence between two Gaussian distributions. Args: - mu (torch.Tensor): The mean of the first Gaussian distribution - logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution - params (dict, optional): A dictionary of optional parameters + mu (torch.Tensor): The mean of the first Gaussian distribution. + logvar (torch.Tensor): The logarithm of the variance of the first Gaussian distribution. + params (Optional[dict], optional): The dictionary of parameters. Defaults to None. Returns: torch.Tensor: The computed Kullback-Leibler divergence @@ -295,15 +299,15 @@ def KullbackLeiblerDivergence(mu, logvar, params=None): def FocalLoss( - predicted: torch.Tensor, target: torch.Tensor, params: dict = None + predicted: torch.Tensor, target: torch.Tensor, params: Optional[dict] = None ) -> torch.Tensor: """ This function calculates the Focal loss between two tensors. Args: - predicted (torch.Tensor): Predicted generally by the network - target (torch.Tensor): Required target label to match the predicted with - params (dict, optional): Additional parameters for computing loss function, including weights for each class + predicted (torch.Tensor): Predicted generally by the network. + target (torch.Tensor): Required target label to match the predicted with. + params (Optional[dict], optional): Additional parameters for computing loss function, including gamma and size_average. Defaults to None. Returns: torch.Tensor: Computed Focal Loss @@ -314,7 +318,9 @@ def FocalLoss( gamma = params["loss_function"].get("gamma", 2.0) size_average = params["loss_function"].get("size_average", True) - def _focal_loss(preds, target, gamma, size_average=True): + def _focal_loss( + preds, target, gamma, size_average: Optional[bool] = True + ) -> torch.Tensor: """ Internal helper function to calculate focal loss for a single class. diff --git a/GANDLF/metrics/classification.py b/GANDLF/metrics/classification.py index ec4e9d160..f48ca37d6 100644 --- a/GANDLF/metrics/classification.py +++ b/GANDLF/metrics/classification.py @@ -1,16 +1,17 @@ +import torch import torchmetrics as tm from torch.nn.functional import one_hot from ..utils import get_output_from_calculator from GANDLF.utils.generic import determine_classification_task_type -def overall_stats(predictions, ground_truth, params): +def overall_stats(prediction: torch.Tensor, target: torch.Tensor, params: dict) -> dict: """ - Generates a dictionary of metrics calculated on the overall predictions and ground truths. + Generates a dictionary of metrics calculated on the overall prediction and ground truths. Args: - predictions (torch.Tensor): The output of the model. - ground_truth (torch.Tensor): The ground truth labels. + prediction (torch.Tensor): The output of the model. + target (torch.Tensor): The ground truth labels. params (dict): The parameter dictionary containing training and data information. Returns: @@ -70,19 +71,19 @@ def overall_stats(predictions, ground_truth, params): for metric_name, calculator in calculators.items(): if metric_name == "aucroc": one_hot_preds = one_hot( - predictions.long(), + prediction.long(), num_classes=params["model"]["num_classes"], ) output_metrics[metric_name] = get_output_from_calculator( - one_hot_preds.float(), ground_truth, calculator + one_hot_preds.float(), target, calculator ) else: output_metrics[metric_name] = get_output_from_calculator( - predictions, ground_truth, calculator + prediction, target, calculator ) - #### HERE WE NEED TO MODIFY TESTS - ROC IS RETURNING A TUPLE. WE MAY ALSO DISCRAD IT #### - # what is AUC metric telling at all? Computing it for predictions and ground truth + #### HERE WE NEED TO MODIFY TESTS - ROC IS RETURNING A TUPLE. WE MAY ALSO DISCARD IT #### + # what is AUC metric telling at all? Computing it for prediction and ground truth # is not making sense # metrics that do not have any "average" parameter # calculators = { @@ -94,14 +95,14 @@ def overall_stats(predictions, ground_truth, params): # for metric_name, calculator in calculators.items(): # if metric_name == "roc": # one_hot_preds = one_hot( - # predictions.long(), num_classes=params["model"]["num_classes"] + # prediction.long(), num_classes=params["model"]["num_classes"] # ) # output_metrics[metric_name] = get_output_from_calculator( - # one_hot_preds.float(), ground_truth, calculator + # one_hot_preds.float(), target, calculator # ) # else: # output_metrics[metric_name] = get_output_from_calculator( - # predictions, ground_truth, calculator + # prediction, target, calculator # ) return output_metrics diff --git a/GANDLF/metrics/generic.py b/GANDLF/metrics/generic.py index 1c90a3a20..2779de626 100644 --- a/GANDLF/metrics/generic.py +++ b/GANDLF/metrics/generic.py @@ -1,5 +1,6 @@ import torch from torchmetrics import ( + Metric, F1Score, Precision, Recall, @@ -15,12 +16,25 @@ ) -def generic_function_output_with_check(predicted_classes, label, metric_function): - if torch.min(predicted_classes) < 0: +def generic_function_output_with_check( + prediction: torch.Tensor, target: torch.Tensor, metric_function: object +) -> torch.Tensor: + """ + This function computes the output of a generic metric function. + + Args: + prediction (torch.Tensor): The prediction of the model. + target (torch.Tensor): The ground truth labels. + metric_function (object): The metric function to be used, which is a wrapper around the torchmetrics class. + + Returns: + torch.Tensor: The output of the metric function. + """ + if torch.min(prediction) < 0: print( "WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations." ) - return torch.zeros((1), device=predicted_classes.device) + return torch.zeros((1), device=prediction.device) else: # I need to do this with try-except, otherwise for binary problems it will # raise and error as the binary metrics do not have .num_classes @@ -30,19 +44,38 @@ def generic_function_output_with_check(predicted_classes, label, metric_function max_clamp_val = metric_function.num_classes - 1 except AttributeError: max_clamp_val = 1 - predicted_new = torch.clamp(predicted_classes.cpu().int(), max=max_clamp_val) - predicted_new = predicted_new.reshape(label.shape) - return metric_function(predicted_new, label.cpu().int()) - - -def generic_torchmetrics_score(output, label, metric_class, metric_key, params): + predicted_new = torch.clamp(prediction.cpu().int(), max=max_clamp_val) + predicted_new = predicted_new.reshape(target.shape) + return metric_function(predicted_new, target.cpu().int()) + + +def generic_torchmetrics_score( + prediction: torch.Tensor, + target: torch.Tensor, + metric_class: Metric, + metric_key: str, + params: dict, +) -> torch.Tensor: + """ + This function computes the output of a generic torchmetrics metric. + + Args: + prediction (torch.Tensor): The prediction of the model. + target (torch.Tensor): The ground truth labels. + metric_class (Metric): The metric class to be used. + metric_key (str): The key for the metric. + params (dict): The parameter dictionary containing training and data information. + + Returns: + torch.Tensor: The output of the metric function. + """ task = determine_classification_task_type(params) num_classes = params["model"]["num_classes"] - predicted_classes = output + predicted_classes = prediction if params["problem_type"] == "classification": - predicted_classes = torch.argmax(output, 1) + predicted_classes = torch.argmax(prediction, 1) elif params["problem_type"] == "segmentation": - label = one_hot(label, params["model"]["class_list"]) + target = one_hot(target, params["model"]["class_list"]) metric_function = metric_class( task=task, num_classes=num_classes, @@ -52,37 +85,53 @@ def generic_torchmetrics_score(output, label, metric_class, metric_key, params): ) return generic_function_output_with_check( - predicted_classes.cpu().int(), label.cpu().int(), metric_function + predicted_classes.cpu().int(), target.cpu().int(), metric_function ) -def recall_score(output, label, params): - return generic_torchmetrics_score(output, label, Recall, "recall", params) +def recall_score( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + return generic_torchmetrics_score(prediction, target, Recall, "recall", params) -def precision_score(output, label, params): - return generic_torchmetrics_score(output, label, Precision, "precision", params) +def precision_score( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + return generic_torchmetrics_score( + prediction, target, Precision, "precision", params + ) -def f1_score(output, label, params): - return generic_torchmetrics_score(output, label, F1Score, "f1", params) +def f1_score( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + return generic_torchmetrics_score(prediction, target, F1Score, "f1", params) -def accuracy(output, label, params): - return generic_torchmetrics_score(output, label, Accuracy, "accuracy", params) +def accuracy( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + return generic_torchmetrics_score(prediction, target, Accuracy, "accuracy", params) -def specificity_score(output, label, params): - return generic_torchmetrics_score(output, label, Specificity, "specificity", params) +def specificity_score( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + return generic_torchmetrics_score( + prediction, target, Specificity, "specificity", params + ) -def iou_score(output, label, params): +def iou_score( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: num_classes = params["model"]["num_classes"] - predicted_classes = output + predicted_classes = prediction if params["problem_type"] == "classification": - predicted_classes = torch.argmax(output, 1) + predicted_classes = torch.argmax(prediction, 1) elif params["problem_type"] == "segmentation": - label = one_hot(label, params["model"]["class_list"]) + target = one_hot(target, params["model"]["class_list"]) task = determine_classification_task_type(params) recall = JaccardIndex( task=task, @@ -92,5 +141,5 @@ def iou_score(output, label, params): ) return generic_function_output_with_check( - predicted_classes.cpu().int(), label.cpu().int(), recall + predicted_classes.cpu().int(), target.cpu().int(), recall ) diff --git a/GANDLF/metrics/regression.py b/GANDLF/metrics/regression.py index 9267ae449..913b37fac 100644 --- a/GANDLF/metrics/regression.py +++ b/GANDLF/metrics/regression.py @@ -1,6 +1,7 @@ """ All the metrics are to be called from here """ + import torch from sklearn.metrics import balanced_accuracy_score import numpy as np @@ -8,56 +9,60 @@ from ..utils import get_output_from_calculator -def classification_accuracy(output, label, params): +def classification_accuracy( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: """ This function computes the classification accuracy. Args: - output (torch.Tensor): The output of the model. - label (torch.Tensor): The ground truth labels. + prediction (torch.Tensor): The prediction of the model. + target (torch.Tensor): The ground truth labels. params (dict): The parameter dictionary containing training and data information. Returns: torch.Tensor: The classification accuracy. """ + predicted_classes = prediction if params["problem_type"] == "classification": - predicted_classes = torch.argmax(output, 1) - else: - predicted_classes = output + predicted_classes = torch.argmax(prediction, 1) - acc = torch.sum(predicted_classes == label.squeeze()) / len(label) + acc = torch.sum(predicted_classes == target.squeeze()) / len(target) return acc -def balanced_acc_score(output, label, params): +def balanced_acc_score( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: """ This function computes the balanced accuracy. Args: - output (torch.Tensor): The output of the model. - label (torch.Tensor): The ground truth labels. + prediction (torch.Tensor): The prediction of the model. + target (torch.Tensor): The ground truth labels. params (dict): The parameter dictionary containing training and data information. Returns: torch.Tensor: The balanced accuracy. """ + predicted_classes = prediction if params["problem_type"] == "classification": - predicted_classes = torch.argmax(output, 1) - else: - predicted_classes = output + predicted_classes = torch.argmax(prediction, 1) return torch.from_numpy( - np.array(balanced_accuracy_score(predicted_classes.cpu(), label.cpu())) + np.array(balanced_accuracy_score(predicted_classes.cpu(), target.cpu())) ) -def per_label_accuracy(output, label, params): +def per_label_accuracy( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: """ This function computes the per class accuracy. Args: - output (torch.Tensor): The output of the model. - label (torch.Tensor): The ground truth labels. + prediction (torch.Tensor): The prediction of the model. + target (torch.Tensor): The ground truth labels. params (dict): The parameter dictionary containing training and data information. Returns: @@ -66,31 +71,31 @@ def per_label_accuracy(output, label, params): if params["problem_type"] == "classification": # ensure this works for multiple batches output_accuracy = torch.zeros(len(params["model"]["class_list"])) - for output_batch, label_batch in zip(output, label): + for output_batch, label_batch in zip(prediction, target): predicted_classes = torch.Tensor([0] * len(params["model"]["class_list"])) label_cpu = torch.Tensor([0] * len(params["model"]["class_list"])) predicted_classes[torch.argmax(output_batch, 0).cpu().item()] = 1 label_cpu[label_batch.cpu().item()] = 1 output_accuracy += (predicted_classes == label_cpu).type(torch.float) - return output_accuracy / len(output) + return output_accuracy / len(prediction) else: - return balanced_acc_score(output, label, params) + return balanced_acc_score(prediction, target, params) -def overall_stats(predictions, ground_truth, params): +def overall_stats(prediction: torch.Tensor, target: torch.Tensor, params: dict) -> dict: """ Generates a dictionary of metrics calculated on the overall predictions and ground truths. Args: - predictions (torch.Tensor): The output of the model. - ground_truth (torch.Tensor): The ground truth labels. + predictions (torch.Tensor): The prediction of the model. + target (torch.Tensor): The ground truth labels. params (dict): The parameter dictionary containing training and data information. Returns: dict: A dictionary of metrics. """ - predictions = predictions.type(torch.float) - ground_truth = ground_truth.type(torch.float) * params["scaling_factor"] + prediction = prediction.type(torch.float) + target = target.type(torch.float) * params["scaling_factor"] assert ( params["problem_type"] == "regression" ), "Only regression is supported for these stats" @@ -108,9 +113,9 @@ def overall_stats(predictions, ground_truth, params): "cosinesimilarity": tm.CosineSimilarity(reduction=reduction_type_key), } for metric_name, calculator in calculators.items(): - output_metrics[ - f"{metric_name}_{reduction_type}" - ] = get_output_from_calculator(predictions, ground_truth, calculator) + output_metrics[f"{metric_name}_{reduction_type}"] = ( + get_output_from_calculator(prediction, target, calculator) + ) # metrics that do not have any "reduction" parameter calculators = { "mse": tm.MeanSquaredError(), @@ -120,7 +125,7 @@ def overall_stats(predictions, ground_truth, params): } for metric_name, calculator in calculators.items(): output_metrics[metric_name] = get_output_from_calculator( - predictions, ground_truth, calculator + prediction, target, calculator ) return output_metrics diff --git a/GANDLF/metrics/segmentation.py b/GANDLF/metrics/segmentation.py index 87b9b9965..52d989dd7 100644 --- a/GANDLF/metrics/segmentation.py +++ b/GANDLF/metrics/segmentation.py @@ -1,6 +1,8 @@ """ All the segmentation metrics are to be called from here """ + +from typing import List, Optional, Tuple, Union import sys import torch import numpy as np @@ -13,7 +15,7 @@ ) -def _convert_tensor_to_int_label_array(input_tensor): +def _convert_tensor_to_int_label_array(input_tensor: torch.Tensor) -> np.ndarray: """ This function converts a tensor of labels to a numpy array of labels. @@ -32,18 +34,23 @@ def _convert_tensor_to_int_label_array(input_tensor): return result_array.astype(np.int64) -def multi_class_dice(output, label, params, per_label=False): +def multi_class_dice( + prediction: torch.Tensor, + target: torch.Tensor, + params: dict, + per_label: Optional[bool] = False, +) -> Union[torch.Tensor, List[float]]: """ This function computes a multi-class dice. Args: - output (torch.Tensor): Input data containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. - label (torch.Tensor): Input data containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + prediction (torch.Tensor): The input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): The input ground truth containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. params (dict): The parameter dictionary containing training and data information. - per_label (bool, optional): Whether the dice needs to be calculated per label or not. Defaults to False. + per_label (Optional[bool], optional): Whether to return per-label scores. Defaults to False. Returns: - float or list: The average dice for all labels or a list of per-label dice scores. + Union[torch.Tensor, List[float]]: The multi-class dice score or the list of per-label dice scores. """ total_dice = 0 avg_counter = 0 @@ -51,7 +58,7 @@ def multi_class_dice(output, label, params, per_label=False): for i in range(0, params["model"]["num_classes"]): # this check should only happen during validation if i != params["model"]["ignore_label_validation"]: - current_dice = dice(output[:, i, ...], label[:, i, ...]) + current_dice = dice(prediction[:, i, ...], target[:, i, ...]) total_dice += current_dice per_label_dice.append(current_dice.item()) avg_counter += 1 @@ -64,42 +71,49 @@ def multi_class_dice(output, label, params, per_label=False): return total_dice -def multi_class_dice_per_label(output, label, params): +def multi_class_dice_per_label( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> list: """ This function computes a multi-class dice. Args: - output (torch.Tensor): Input data containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. - label (torch.Tensor): Input data containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + prediction (torch.Tensor): Input data containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input data containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. params (dict): The parameter dictionary containing training and data information. Returns: list: The list of per-label dice scores. """ - return multi_class_dice(output, label, params, per_label=True) + return multi_class_dice(prediction, target, params, per_label=True) -def __surface_distances(result, reference, voxelspacing=None, connectivity=1): +def __surface_distances( + prediction: torch.Tensor, + target: torch.Tensor, + voxel_spacing: Optional[Tuple[float]] = None, + connectivity: Optional[int] = 1, +) -> float: """ The distances between the surface voxel of binary objects in result and their nearest partner surface voxel of a binary object in reference. Adapted from https://github.com/loli/medpy/blob/39131b94f0ab5328ab14a874229320efc2f74d98/medpy/metric/binary.py#L1195. Args: - result (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. - reference (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. - voxelspacing (tuple): The size of each voxel, defaults to isotropic spacing of 1mm. - connectivity (int): The connectivity of regions. See scipy.ndimage.generate_binary_structure for more information. + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + voxel_spacing (Optional[Tuple[float]], optional): The voxel spacing. Defaults to None. + connectivity (Optional[int], optional): The voxel connectivity. Defaults to 1. Returns: float: The symmetric Hausdorff Distance between the object(s) in ```result``` and the object(s) in ```reference```. The distance unit is the same as for the spacing of elements along each dimension, which is usually given in mm. """ - result = np.atleast_1d(result.astype(bool)) - reference = np.atleast_1d(reference.astype(bool)) - if voxelspacing is not None: - voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim) - voxelspacing = np.asarray(voxelspacing, dtype=np.float64) - if not voxelspacing.flags.contiguous: - voxelspacing = voxelspacing.copy() + result = np.atleast_1d(prediction.astype(bool)) + reference = np.atleast_1d(target.astype(bool)) + if voxel_spacing is not None: + voxel_spacing = _ni_support._normalize_sequence(voxel_spacing, result.ndim) + voxel_spacing = np.asarray(voxel_spacing, dtype=np.float64) + if not voxel_spacing.flags.contiguous: + voxel_spacing = voxel_spacing.copy() # binary structure footprint = generate_binary_structure(result.ndim, connectivity) @@ -119,13 +133,13 @@ def __surface_distances(result, reference, voxelspacing=None, connectivity=1): # compute average surface distance # Note: scipys distance transform is calculated only inside the borders of the # foreground objects, therefore the input has to be reversed - dt = distance_transform_edt(~reference_border, sampling=voxelspacing) + dt = distance_transform_edt(~reference_border, sampling=voxel_spacing) sds = dt[result_border] return sds -def _nsd_base(a_to_b, b_to_a, threshold): +def _nsd_base(a_to_b: np.ndarray, b_to_a: np.ndarray, threshold: float) -> float: """ This implementation differs from the official surface dice implementation! These two are not comparable!!!!! The normalized surface dice is symmetric, so it should not matter whether a or b is the reference image @@ -154,24 +168,24 @@ def _nsd_base(a_to_b, b_to_a, threshold): def _calculator_jaccard( - inp, - target, - params, - per_label=False, -): + prediction: torch.Tensor, + target: torch.Tensor, + params: dict, + per_label: Optional[bool] = False, +) -> torch.Tensor: """ This function returns sensitivity and specificity. Args: - inp (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. params (dict): The parameter dictionary containing training and data information. - per_label (bool): Whether to return per-label dice scores. + per_label (Optional[bool], optional): Whether to return per-label scores. Defaults to False. Returns: float: The Jaccard score between the object(s) in ```inp``` and the object(s) in ```target```. """ - result_array = _convert_tensor_to_int_label_array(inp) + result_array = _convert_tensor_to_int_label_array(prediction) target_array = _convert_tensor_to_int_label_array(target) jaccard, avg_counter = 0, 0 @@ -194,19 +208,19 @@ def _calculator_jaccard( def _calculator_sensitivity_specificity( - inp, - target, - params, - per_label=False, -): + prediction: torch.Tensor, + target: torch.Tensor, + params: dict, + per_label: Optional[bool] = False, +) -> torch.Tensor: """ This function returns sensitivity and specificity. Args: - inp (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. params (dict): The parameter dictionary containing training and data information. - per_label (bool): Whether to return per-label dice scores. + per_label (Optional[bool], optional): Whether to return per-label scores. Defaults to False. Returns: float, float: The sensitivity and specificity between the object(s) in ```inp``` and the object(s) in ```target```. @@ -237,7 +251,7 @@ def get_sensitivity_and_specificity(result_array, target_array): return Sens, Spec - result_array = _convert_tensor_to_int_label_array(inp) + result_array = _convert_tensor_to_int_label_array(prediction) target_array = _convert_tensor_to_int_label_array(target) sensitivity, specificity, avg_counter = 0, 0, 0 @@ -264,23 +278,25 @@ def get_sensitivity_and_specificity(result_array, target_array): def _calculator_generic_all_surface_distances( - inp, - target, - params, - per_label=False, -): + prediction: torch.Tensor, + target: torch.Tensor, + params: dict, + per_label: Optional[bool] = False, +) -> torch.Tensor: """ This function returns hd100, hd95, and nsd. Args: - inp (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. params (dict): The parameter dictionary containing training and data information. + per_label (Optional[bool], optional): Whether to return per-label scores. Defaults to False. + Returns: float, float, float: The Normalized Surface Dice, 100th percentile Hausdorff Distance, and the 95th percentile Hausdorff Distance. """ - result_array = _convert_tensor_to_int_label_array(inp) + result_array = _convert_tensor_to_int_label_array(prediction) target_array = _convert_tensor_to_int_label_array(target) avg_counter = 0 @@ -330,32 +346,30 @@ def _calculator_generic_all_surface_distances( def _calculator_generic( - inp, - target, - params, - percentile=95, - surface_dice=False, - per_label=False, -): + prediction: torch.Tensor, + target: torch.Tensor, + params: dict, + percentile: int = 95, + surface_dice: Optional[bool] = False, + per_label: Optional[bool] = False, +) -> Union[torch.Tensor, List[float]]: """ Generic Surface Dice (SD)/Hausdorff (HD) Distance calculation from 2 tensors. Compared to the standard Hausdorff Distance, this metric is slightly more stable to small outliers and is commonly used in Biomedical Segmentation challenges. Args: - inp (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. params (dict): The parameter dictionary containing training and data information. - percentile (int, optional): The percentile of surface distances to include during HD calculation. Defaults to 95. - surface_dice (bool, optional): Whether the SD needs to be calculated or not. Defaults to False. - per_label (bool, optional): Whether the hausdorff needs to be calculated per label or not. Defaults to False. + percentile (int, optional): The percentile to calculate the Hausdorff Distance. Defaults to 95. + surface_dice (Optional[bool], optional): Whether to return the surface dice. Defaults to False. + per_label (Optional[bool], optional): Whether to return per-label scores. Defaults to False. - Returns: - float or list: The symmetric Hausdorff Distance or Normalized Surface Distance between the object(s) in ```result``` and the object(s) in ```reference```. The distance unit is the same as for the spacing of elements along each dimension, which is usually given in mm. - See also: - :func:`hd` + Returns: + Union[torch.Tensor, List[float]]: The Normalized Surface Dice, 100th percentile Hausdorff Distance, and the 95th percentile Hausdorff Distance, or the list of per-label scores for each metric. """ _nsd, _hd100, _hd95 = _calculator_generic_all_surface_distances( - inp, target, params, per_label=per_label + prediction, target, params, per_label=per_label ) if surface_dice: return _nsd @@ -365,57 +379,217 @@ def _calculator_generic( return _hd100 -def hd95(inp, target, params): - return _calculator_generic(inp, target, params, percentile=95) +def hd95(prediction: torch.Tensor, target: torch.Tensor, params: dict) -> torch.Tensor: + """ + This function returns the 95th percentile Hausdorff Distance. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + torch.Tensor: The 95th percentile Hausdorff Distance. + """ + return _calculator_generic(prediction, target, params, percentile=95) + + +def hd95_per_label( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> List[float]: + """ + This function returns the per-label 95th percentile Hausdorff Distance. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + List[float]: The list of per-label 95th percentile Hausdorff Distances. + """ + return _calculator_generic( + prediction, target, params, percentile=95, per_label=True + ) + + +def hd100(prediction: torch.Tensor, target: torch.Tensor, params: dict) -> torch.Tensor: + """ + This function returns the 100th percentile Hausdorff Distance. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + Returns: + torch.Tensor: The 100th percentile Hausdorff Distance. + """ + return _calculator_generic(prediction, target, params, percentile=100) -def hd95_per_label(inp, target, params): - return _calculator_generic(inp, target, params, percentile=95, per_label=True) +def hd100_per_label( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> List[float]: + """ + This function returns the per-label 100th percentile Hausdorff Distance. -def hd100(inp, target, params): - return _calculator_generic(inp, target, params, percentile=100) + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + Returns: + List[float]: The list of per-label 100th percentile Hausdorff Distances. + """ + return _calculator_generic( + prediction, target, params, percentile=100, per_label=True + ) -def hd100_per_label(inp, target, params): - return _calculator_generic(inp, target, params, percentile=100, per_label=True) +def nsd(prediction: torch.Tensor, target: torch.Tensor, params: dict) -> torch.Tensor: + """ + This function returns the Normalized Surface Dice. -def nsd(inp, target, params): - return _calculator_generic(inp, target, params, percentile=100, surface_dice=True) + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + Returns: + torch.Tensor: The Normalized Surface Dice. + """ + return _calculator_generic( + prediction, target, params, percentile=100, surface_dice=True + ) + + +def nsd_per_label( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> List[float]: + """ + This function returns the per-label Normalized Surface Dice. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. -def nsd_per_label(inp, target, params): + Returns: + List[float]: The list of per-label Normalized Surface Dice scores. + """ return _calculator_generic( - inp, target, params, percentile=100, per_label=True, surface_dice=True + prediction, target, params, percentile=100, per_label=True, surface_dice=True ) -def sensitivity(inp, target, params): - s, _ = _calculator_sensitivity_specificity(inp, target, params) +def sensitivity( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + """ + This function returns the sensitivity. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + torch.Tensor: The sensitivity. + """ + s, _ = _calculator_sensitivity_specificity(prediction, target, params) return s -def sensitivity_per_label(inp, target, params): - s, _ = _calculator_sensitivity_specificity(inp, target, params, per_label=True) +def sensitivity_per_label( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> List[float]: + """ + This function returns the per-label sensitivity. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + List[float]: The list of per-label sensitivity scores. + """ + s, _ = _calculator_sensitivity_specificity( + prediction, target, params, per_label=True + ) return s -def specificity_segmentation(inp, target, params): - _, p = _calculator_sensitivity_specificity(inp, target, params) +def specificity_segmentation( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + """ + This function returns the specificity. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + torch.Tensor: The specificity. + """ + _, p = _calculator_sensitivity_specificity(prediction, target, params) return p -def specificity_segmentation_per_label(inp, target, params): - _, p = _calculator_sensitivity_specificity(inp, target, params, per_label=True) +def specificity_segmentation_per_label( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> List[float]: + """ + This function returns the per-label specificity. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + List[float]: The list of per-label specificity scores. + """ + _, p = _calculator_sensitivity_specificity( + prediction, target, params, per_label=True + ) return p -def jaccard(inp, target, params): - j = _calculator_jaccard(inp, target, params) +def jaccard( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> torch.Tensor: + """ + This function returns the Jaccard score. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + torch.Tensor: The Jaccard score. + """ + j = _calculator_jaccard(prediction, target, params) return j -def jaccard_per_label(inp, target, params): - j = _calculator_jaccard(inp, target, params, per_label=True) +def jaccard_per_label( + prediction: torch.Tensor, target: torch.Tensor, params: dict +) -> List[float]: + """ + This function returns the per-label Jaccard score. + + Args: + prediction (torch.Tensor): Input prediction containing objects. Can be any type but will be converted into binary: background where 0, object everywhere else. + target (torch.Tensor): Input ground truth containing objects. Can be any type but will be converted into binary: binary: background where 0, object everywhere else. + params (dict): The parameter dictionary containing training and data information. + + Returns: + List[float]: The list of per-label Jaccard scores. + """ + j = _calculator_jaccard(prediction, target, params, per_label=True) return j diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index 51105d689..ba1b4113e 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -1,3 +1,4 @@ +from typing import Optional import SimpleITK as sitk import PIL.Image import numpy as np @@ -12,14 +13,16 @@ from GANDLF.utils import get_image_from_tensor -def structural_similarity_index(target, prediction, mask=None) -> torch.Tensor: +def structural_similarity_index( + prediction: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None +) -> torch.Tensor: """ Computes the structural similarity index between the target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. - mask (torch.Tensor, optional): The mask tensor. Defaults to None. + target (torch.Tensor): The target tensor. + mask (Optional[torch.Tensor], optional): The mask tensor. Defaults to None. Returns: torch.Tensor: The structural similarity index. @@ -36,31 +39,36 @@ def structural_similarity_index(target, prediction, mask=None) -> torch.Tensor: return ssim_idx.mean() -def mean_squared_error(target, prediction) -> torch.Tensor: +def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Computes the mean squared error between the target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. """ mse = MeanSquaredError() return mse(preds=prediction, target=target) def peak_signal_noise_ratio( - target, prediction, data_range=None, epsilon=None + target: torch.Tensor, + prediction: torch.Tensor, + data_range: Optional[tuple] = None, + epsilon: Optional[float] = None, ) -> torch.Tensor: """ Computes the peak signal to noise ratio between the target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. - data_range (tuple, optional): If not None, this data range (min, max) is used as enumerator instead of computing it from the given data. Defaults to None. - epsilon (float, optional): If not None, this epsilon is added to the denominator of the fraction to avoid infinity as output. Defaults to None. - """ + target (torch.Tensor): The target tensor. + data_range (Optional[tuple], optional): The data range. Defaults to None. + epsilon (Optional[float], optional): The epsilon value. Defaults to None. + Returns: + torch.Tensor: The peak signal to noise ratio. + """ if epsilon == None: psnr = ( PeakSignalNoiseRatio() @@ -80,40 +88,48 @@ def peak_signal_noise_ratio( return 10.0 * torch.log10(((max_v - min_v) ** 2) / (mse + epsilon)) -def mean_squared_log_error(target, prediction) -> torch.Tensor: +def mean_squared_log_error( + prediction: torch.Tensor, target: torch.Tensor +) -> torch.Tensor: """ Computes the mean squared log error between the target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. + + Returns: + torch.Tensor: The mean squared log error. """ mle = MeanSquaredLogError() return mle(preds=prediction, target=target) -def mean_absolute_error(target, prediction) -> torch.Tensor: +def mean_absolute_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Computes the mean absolute error between the target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. + + Returns: + torch.Tensor: The mean absolute error. """ mae = MeanAbsoluteError() return mae(preds=prediction, target=target) -def _get_ncc_image(target, prediction) -> sitk.Image: +def _get_ncc_image(prediction: torch.Tensor, target: torch.Tensor) -> sitk.Image: """ Computes normalized cross correlation image between target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. Returns: - torch.Tensor: The normalized cross correlation image. + sitk.Image: The normalized cross correlation image. """ def __convert_to_grayscale(image: sitk.Image) -> sitk.Image: @@ -142,13 +158,13 @@ def __convert_to_grayscale(image: sitk.Image) -> sitk.Image: return correlation_filter.Execute(target_image, pred_image) -def ncc_mean(target, prediction) -> float: +def ncc_mean(prediction: torch.Tensor, target: torch.Tensor) -> float: """ Computes normalized cross correlation mean between target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. Returns: float: The normalized cross correlation mean. @@ -159,13 +175,13 @@ def ncc_mean(target, prediction) -> float: return stats_filter.GetMean() -def ncc_std(target, prediction) -> float: +def ncc_std(prediction: torch.Tensor, target: torch.Tensor) -> float: """ Computes normalized cross correlation standard deviation between target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. Returns: float: The normalized cross correlation standard deviation. @@ -176,13 +192,13 @@ def ncc_std(target, prediction) -> float: return stats_filter.GetSigma() -def ncc_max(target, prediction) -> float: +def ncc_max(prediction: torch.Tensor, target: torch.Tensor) -> float: """ Computes normalized cross correlation maximum between target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. Returns: float: The normalized cross correlation maximum. @@ -193,13 +209,13 @@ def ncc_max(target, prediction) -> float: return stats_filter.GetMaximum() -def ncc_min(target, prediction) -> float: +def ncc_min(prediction: torch.Tensor, target: torch.Tensor) -> float: """ Computes normalized cross correlation minimum between target and prediction. Args: - target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. Returns: float: The normalized cross correlation minimum. diff --git a/GANDLF/models/densenet.py b/GANDLF/models/densenet.py index afb08d959..7abb3f7d5 100644 --- a/GANDLF/models/densenet.py +++ b/GANDLF/models/densenet.py @@ -14,7 +14,7 @@ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, Norm, Co """ Constructor for _DenseLayer class. - Parameters: + Args: num_input_features (int): Number of input channels to the layer. growth_rate (int): Number of output channels of each convolution operation in the layer. bn_size (int): Factor to scale the number of intermediate channels between the 1x1 and 3x3 convolutions. @@ -61,7 +61,7 @@ def forward(self, x): """ Forward pass through the _DenseLayer. - Parameters: + Args: x (torch.Tensor): Input tensor. Returns: @@ -94,7 +94,7 @@ def __init__( """ Constructor for _DenseBlock class. - Parameters: + Args: num_layers (int): Number of dense layers to be added to the block. num_input_features (int): Number of input channels to the block. bn_size (int): Factor to scale the number of intermediate channels between the 1x1 and 3x3 convolutions in each dense layer. diff --git a/GANDLF/models/light_unet_multilayer.py b/GANDLF/models/light_unet_multilayer.py index f8c4faa9f..44030a737 100644 --- a/GANDLF/models/light_unet_multilayer.py +++ b/GANDLF/models/light_unet_multilayer.py @@ -2,6 +2,7 @@ """ Implementation of Light UNet """ +import torch from torch.nn import ModuleList from GANDLF.models.seg_modules.DownsamplingModule import DownsamplingModule @@ -103,18 +104,15 @@ def __init__( self.de[i_lay] = self.converter(self.de[i_lay]).model self.en[i_lay] = self.converter(self.en[i_lay]).model - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Parameters - ---------- - x : Tensor - Should be a 5D Tensor as [batch_size, channels, x_dims, y_dims, z_dims]. + Forward pass of the network. - Returns - ------- - x : Tensor - Returns a 5D Output Tensor as [batch_size, n_classes, x_dims, y_dims, z_dims]. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor. """ y = [] y.append(self.ins(x)) diff --git a/GANDLF/models/modelBase.py b/GANDLF/models/modelBase.py index 7cc4d7e24..226c6f911 100644 --- a/GANDLF/models/modelBase.py +++ b/GANDLF/models/modelBase.py @@ -95,16 +95,16 @@ def __init__(self, parameters): ) ) - def get_final_layer(self, final_convolution_layer): + def get_final_layer(self, final_convolution_layer: str) -> nn.Module: return get_modelbase_final_layer(final_convolution_layer) - def get_norm_type(self, norm_type, dimensions): + def get_norm_type(self, norm_type: str, dimensions: int) -> nn.Module: """ This function gets the normalization type for the model. Args: norm_type (str): Normalization type as a string. - dimensions (str): The dimensionality of the model. + dimensions (int): The dimensionality of the model. Returns: _InstanceNorm or _BatchNorm: The normalization type for the model. @@ -126,7 +126,7 @@ def get_norm_type(self, norm_type, dimensions): return norm_type - def model_depth_check(self, parameters): + def model_depth_check(self, parameters: dict) -> int: """ This function checks if the patch size is large enough for the model. @@ -135,9 +135,6 @@ def model_depth_check(self, parameters): Returns: int: The model depth to use. - - Raises: - AssertionError: If the patch size is not large enough for the model. """ model_depth = checkPatchDimensions( parameters["patch_size"], numlay=parameters["model"]["depth"] diff --git a/GANDLF/models/resnet.py b/GANDLF/models/resnet.py index f1932b815..e8dde75b8 100644 --- a/GANDLF/models/resnet.py +++ b/GANDLF/models/resnet.py @@ -5,6 +5,7 @@ import numpy as np from .modelBase import ModelBase +from GANDLF.utils import getBase2 class ResNet(ModelBase): @@ -487,23 +488,6 @@ def checkPatchDimensions(patch_size, numlay): return numlay -def getBase2(num): - """ - Compute the base 2 logarithm of a number. - - Args: - num (int): the number - - Returns: - int: the base 2 logarithm of the number - """ - base = 0 - while num % 2 == 0: - num = num / 2 - base = base + 1 - return base - - def resnet18(parameters): """ Create a ResNet-18 model with the given parameters. diff --git a/GANDLF/models/seg_modules/IncDropout.py b/GANDLF/models/seg_modules/IncDropout.py index 6bf470b2e..fb08bf62c 100644 --- a/GANDLF/models/seg_modules/IncDropout.py +++ b/GANDLF/models/seg_modules/IncDropout.py @@ -4,34 +4,33 @@ class IncDropout(nn.Module): def __init__( self, - input_channels, - output_channels, - Conv, - Dropout, - InstanceNorm, - dropout_p=0.3, - leakiness=1e-2, - conv_bias=True, - inst_norm_affine=True, - res=False, - lrelu_inplace=True, + input_channels: int, + output_channels: int, + Conv: nn.Module = nn.Conv2d, + Dropout: nn.Module = nn.Dropout2d, + InstanceNorm: nn.Module = nn.InstanceNorm2d, + dropout_p: float = 0.3, + leakiness: float = 1e-2, + conv_bias: bool = True, + inst_norm_affine: bool = True, + res: bool = False, + lrelu_inplace: bool = True, ): """ - Incremental Dropout module with a 1x1 convolutional layer. + Incremental dropout module. - Parameters - ---------- - input_channels (int): Number of input channels. - output_channels (int): Number of output channels. - Conv (torch.nn.Module, optional): Convolutional layer to use. - Dropout (torch.nn.Module, optional): Dropout layer to use. - InstanceNorm (torch.nn.Module, optional): Instance normalization layer to use. - dropout_p (float, optional): Probability of an element to be zeroed. Default is 0.3. - leakiness (float, optional): Negative slope coefficient for LeakyReLU activation. Default is 1e-2. - conv_bias (bool, optional): If True, add a bias term to the convolutional layer. Default is True. - inst_norm_affine (bool, optional): If True, learn two affine parameters per channel in the instance normalization layer. Default is True. - res (bool, optional): If True, add a residual connection to the module. Default is False. - lrelu_inplace (bool, optional): If True, perform the LeakyReLU operation in place. Default is True. + Args: + input_channels (int): Number of input channels. + output_channels (int): Number of output channels. + Conv (nn.Module, optional): The convolutional layer type. Defaults to nn.Conv2d. + Dropout (nn.Module, optional): The dropout layer type. Defaults to nn.Dropout2d. + InstanceNorm (nn.Module, optional): The instance normalization layer type. Defaults to nn.InstanceNorm2d. + dropout_p (float, optional): The dropout probability. Defaults to 0.3. + leakiness (float, optional): The leakiness of the leaky ReLU. Defaults to 1e-2. + conv_bias (bool, optional): The bias in the convolutional layer. Defaults to True. + inst_norm_affine (bool, optional): Whether to use the affine transformation in the instance normalization layer. Defaults to True. + res (bool, optional): Whether to use residual connections. Defaults to False. + lrelu_inplace (bool, optional): Whether to use the inplace version of the leaky ReLU. Defaults to True. """ nn.Module.__init__(self) diff --git a/GANDLF/models/unet_multilayer.py b/GANDLF/models/unet_multilayer.py index 3b735befd..0cae21946 100644 --- a/GANDLF/models/unet_multilayer.py +++ b/GANDLF/models/unet_multilayer.py @@ -2,6 +2,7 @@ """ Implementation of UNet """ +import torch from torch.nn import ModuleList from GANDLF.models.seg_modules.DownsamplingModule import DownsamplingModule @@ -24,15 +25,15 @@ class unet_multilayer(ModelBase): def __init__( self, parameters: dict, - residualConnections=False, + residualConnections:bool=False, ): """ - Parameters - ---------- - parameters (dict): A dictionary containing the model parameters. - residualConnections (bool, optional): A flag to control residual connections in the model, by default False. - """ + The constructor for the unet_multilayer class. + Args: + parameters (dict): A dictionary containing the model parameters. + residualConnections (bool, optional): Flag to control residual connections. Defaults to False. + """ self.network_kwargs = {"res": residualConnections} super(unet_multilayer, self).__init__(parameters) @@ -119,16 +120,15 @@ def __init__( self.de[i_lay] = self.converter(self.de[i_lay]).model self.en[i_lay] = self.converter(self.en[i_lay]).model - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Forward pass of the UNet model. + Forward pass of the U-Net model. Args: - x (Tensor): Should be a 5D Tensor as [batch_size, channels, x_dims, y_dims, z_dims]. + x (torch.Tensor): The input tensor. Returns: - x (Tensor): Returns a 5D Output Tensor as [batch_size, n_classes, x_dims, y_dims, z_dims]. - + torch.Tensor: The output tensor. """ # Store intermediate feature maps @@ -167,10 +167,4 @@ class resunet_multilayer(unet_multilayer): """ def __init__(self, parameters: dict): - """ - Parameters - ---------- - parameters (dict): A dictionary containing the model parameters. - - """ super(resunet_multilayer, self).__init__(parameters, residualConnections=True) diff --git a/GANDLF/models/unetr.py b/GANDLF/models/unetr.py index d656cf421..58a85232e 100644 --- a/GANDLF/models/unetr.py +++ b/GANDLF/models/unetr.py @@ -450,7 +450,7 @@ class _Transformer(nn.Sequential): """ A transformer module that consists of an embedding layer followed by a series of transformer layers. - Parameters: + Args: img_size (tuple): The dimensions of the input image (height, width, depth). patch_size (int): The size of the patches to be extracted from the input image. in_feats (int): The number of input features. @@ -515,7 +515,7 @@ def forward(self, x): """ Processes the input through the transformer and returns the output. - Parameters: + Args: x (tensor): The input tensor. Returns: @@ -569,13 +569,6 @@ def __init__( Args: parameters (dict): A dictionary containing the model parameters. - - Raises: - ------- - AssertionError - If the input image size is not divisible by the patch size in at least 1 dimension, or if the inner patch size is not smaller than the input image. - If the embedding dimension is not divisible by the number of self-attention heads. - """ super(unetr, self).__init__(parameters) @@ -713,21 +706,16 @@ def __init__( ), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Perform the forward pass of the UNet model. + Forward pass of the U-Net model. - Parameters - ---------- - x : torch.Tensor - The input tensor of shape [batch_size, channels, x_dims, y_dims, z_dims]. + Args: + x (torch.Tensor): The input tensor. - Returns - ------- - torch.Tensor - The output tensor of shape [batch_size, n_classes, x_dims, y_dims, z_dims]. + Returns: + torch.Tensor: The output tensor. """ - # Perform transformer encoding of input tensor transformer_out = self.transformer(x) diff --git a/GANDLF/parseConfig.py b/GANDLF/parseConfig.py index 3b51f9d54..d67cd347b 100644 --- a/GANDLF/parseConfig.py +++ b/GANDLF/parseConfig.py @@ -1,7 +1,10 @@ +from typing import Optional, Union from .config_manager import ConfigManager -def parseConfig(config_file_path, version_check_flag=True) -> None: +def parseConfig( + config_file_path: Union[str, dict], version_check_flag: Optional[bool] = True +) -> None: """ This function parses the configuration file and returns a dictionary of parameters. diff --git a/GANDLF/training_manager.py b/GANDLF/training_manager.py index 986b7f1f0..41c188d74 100644 --- a/GANDLF/training_manager.py +++ b/GANDLF/training_manager.py @@ -7,7 +7,14 @@ from GANDLF.utils import get_dataframe -def TrainingManager(dataframe, outputDir, parameters, device, resume, reset): +def TrainingManager( + dataframe: pd.DataFrame, + outputDir: str, + parameters: dict, + device: str, + resume: bool, + reset: bool, +) -> None: """ This is the training manager that ties all the training functionality together @@ -303,22 +310,22 @@ def TrainingManager(dataframe, outputDir, parameters, device, resume, reset): def TrainingManager_split( - dataframe_train, - dataframe_validation, - dataframe_testing, - outputDir, - parameters, - device, - resume, - reset, + dataframe_train: pd.DataFrame, + dataframe_validation: pd.DataFrame, + dataframe_testing: pd.DataFrame, + outputDir: str, + parameters: dict, + device: str, + resume: bool, + reset: bool, ): """ This is the training manager that ties all the training functionality together Args: - dataframe_train (pandas.DataFrame): The training data from CSV. - dataframe_validation (pandas.DataFrame): The validation data from CSV. - dataframe_testing (pandas.DataFrame): The testing data from CSV. + dataframe_train (pd.DataFrame): The training data from CSV. + dataframe_validation (pd.DataFrame): The validation data from CSV. + dataframe_testing (pd.DataFrame): The testing data from CSV. outputDir (str): The main output directory. parameters (dict): The parameters dictionary. device (str): The device to perform computations on. diff --git a/GANDLF/utils/__init__.py b/GANDLF/utils/__init__.py index 39173e2b9..357050dd0 100644 --- a/GANDLF/utils/__init__.py +++ b/GANDLF/utils/__init__.py @@ -9,6 +9,7 @@ perform_sanity_check_on_subject, write_training_patches, get_correct_padding_size, + applyCustomColorMap, ) from .tensor import ( @@ -52,6 +53,7 @@ set_determinism, print_and_format_metrics, determine_classification_task_type, + getBase2, ) from .modelio import ( diff --git a/GANDLF/utils/generic.py b/GANDLF/utils/generic.py index 52f6a33d9..6b8fed638 100644 --- a/GANDLF/utils/generic.py +++ b/GANDLF/utils/generic.py @@ -6,7 +6,7 @@ import SimpleITK as sitk from contextlib import contextmanager, redirect_stderr, redirect_stdout from os import devnull -from typing import Dict, Any, Union +from typing import Optional, Union @contextmanager @@ -17,13 +17,13 @@ def suppress_stdout_stderr(): yield (err, out) -def checkPatchDivisibility(patch_size, number=16): +def checkPatchDivisibility(patch_size: np.ndarray, number: Optional[int] = 16) -> bool: """ This function checks the divisibility of a numpy array or integer for architectural integrity Args: - patch_size (numpy.array): The patch size for checking. - number (int, optional): The number to check divisibility for. Defaults to 16. + patch_size (np.ndarray): The patch size for checking. + number (Optional[int], optional): The number to check divisibility for. Defaults to 16. Returns: bool: If all elements of array are divisible or not, after taking 2D patches into account. @@ -49,21 +49,7 @@ def checkPatchDivisibility(patch_size, number=16): return True -def determine_classification_task_type( - params: Dict[str, Union[Dict[str, Any], Any]] -) -> str: - """Determine the task (binary or multiclass) from the model config. - Args: - params (dict): The parameter dictionary containing training and data information. - - Returns: - str: A string that denotes the classification task type. - """ - task = "binary" if params["model"]["num_classes"] == 2 else "multiclass" - return task - - -def get_date_time(): +def get_date_time() -> str: """ Get a well-parsed date string @@ -74,7 +60,7 @@ def get_date_time(): return now -def get_unique_timestamp(): +def get_unique_timestamp() -> str: """ Get a well-parsed timestamp string to be used for unique filenames @@ -85,7 +71,7 @@ def get_unique_timestamp(): return now -def get_filename_extension_sanitized(filename): +def get_filename_extension_sanitized(filename: str) -> str: """ This function returns the extension of the filename with leading and trailing characters removed. Args: @@ -100,7 +86,7 @@ def get_filename_extension_sanitized(filename): return ext -def parse_version(version_string): +def parse_version(version_string: str) -> int: """ Parses version string, discards last identifier (NR/alpha/beta) and returns an integer for comparison. @@ -117,7 +103,7 @@ def parse_version(version_string): return int("".join(version_string_split)) -def version_check(version_from_config, version_to_check): +def version_check(version_from_config: str, version_to_check: str) -> bool: """ This function checks if the version of the config file is compatible with the version of the code. @@ -141,12 +127,12 @@ def version_check(version_from_config, version_to_check): return True -def checkPatchDimensions(patch_size, numlay): +def checkPatchDimensions(patch_size: np.ndarray, numlay: int) -> int: """ This function checks the divisibility of a numpy array or integer for architectural integrity Args: - patch_size (numpy.array): The patch size for checking. + patch_size (np.ndarray): The patch size for checking. number (int, optional): The number to check divisibility for. Defaults to 16. Returns: @@ -173,7 +159,16 @@ def checkPatchDimensions(patch_size, numlay): return int(np.min(layers)) -def getBase2(num): +def getBase2(num: int) -> int: + """ + Compute the base 2 logarithm of a number. + + Args: + num (int): the number + + Returns: + int: the base 2 logarithm of the number + """ # helper for checkPatchDimensions (returns the largest multiple of 2 that num is evenly divisible by) base = 0 while num % 2 == 0: @@ -182,29 +177,37 @@ def getBase2(num): return base -def get_array_from_image_or_tensor(input_tensor_or_image): +def get_array_from_image_or_tensor( + input_tensor_or_image: Union[torch.Tensor, sitk.Image] +) -> np.ndarray: """ - This function returns the numpy array from a tensor or image. + This function returns the numpy array from a torch.Tensor or sitk.Image. + Args: - input_tensor_or_image (torch.Tensor or sitk.Image): The input tensor or image. + input_tensor_or_image (Union[torch.Tensor, sitk.Image]): The input tensor or image. + Returns: - numpy.array: The numpy array from the tensor or image. + np.ndarray: The numpy array. """ + assert isinstance( + input_tensor_or_image, (torch.Tensor, sitk.Image, np.ndarray) + ), "Input must be a torch.Tensor or sitk.Image or np.ndarray, but got " + str( + type(input_tensor_or_image) + ) if isinstance(input_tensor_or_image, torch.Tensor): return input_tensor_or_image.detach().cpu().numpy() elif isinstance(input_tensor_or_image, sitk.Image): return sitk.GetArrayFromImage(input_tensor_or_image) elif isinstance(input_tensor_or_image, np.ndarray): return input_tensor_or_image - else: - raise ValueError("Input must be a torch.Tensor or sitk.Image or np.ndarray") -def set_determinism(seed=42): +def set_determinism(seed: Optional[int] = 42) -> None: """ - This function controls the randomness of the program. It sets the seed both for torch and numpy. + This function sets the determinism for the random number generators. + Args: - seed (int, optional): Seed to set. Defaults to 42. + seed (Optional[int], optional): The seed for the random number generators. Defaults to 42. """ random.seed(seed) np.random.seed(seed) @@ -220,12 +223,12 @@ def set_determinism(seed=42): def print_and_format_metrics( - cohort_level_metrics, - sample_level_metrics, - metrics_dict_from_parameters, - mode, - length_of_dataloader, -): + cohort_level_metrics: dict, + sample_level_metrics: dict, + metrics_dict_from_parameters: dict, + mode: str, + length_of_dataloader: int, +) -> dict: """ This function prints and formats the metrics. @@ -240,14 +243,15 @@ def print_and_format_metrics( dict: The metrics dictionary populated with the metrics. """ - def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict: + def __update_metric_from_list_to_single_string(input_metrics_dict: dict) -> dict: """ - Helper function updates the metrics dictionary to have a single string for each metric. + Helper function to update the metric from list to single string. Args: input_metrics_dict (dict): The input metrics dictionary. + Returns: - dict: The updated metrics dictionary. + dict: The output metrics dictionary. """ print(input_metrics_dict) output_metrics_dict = deepcopy(input_metrics_dict) @@ -283,12 +287,10 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict: return output_metrics_dict -def define_average_type_key( - params: Dict[str, Union[Dict[str, Any], Any]], metric_name: str -) -> str: - """Determine if the the 'average' filed is defined in the metric config. - If not, fallback to the default 'macro' - values. +def define_average_type_key(params: dict, metric_name: str) -> str: + """ + Determine the average type key from the metric config. + Args: params (dict): The parameter dictionary containing training and data information. metric_name (str): The name of the metric. @@ -300,15 +302,30 @@ def define_average_type_key( return average_type_key -def define_multidim_average_type_key(params, metric_name) -> str: - """Determine if the the 'multidim_average' filed is defined in the metric config. - If not, fallback to the default 'global'. +def define_multidim_average_type_key(params: dict, metric_name: str) -> str: + """ + Determine the multidimensional average type key from the metric config. + Args: params (dict): The parameter dictionary containing training and data information. metric_name (str): The name of the metric. Returns: - str: The average type key. + str: The multidimensional average type key. """ average_type_key = params["metrics"][metric_name].get("multidim_average", "global") return average_type_key + + +def determine_classification_task_type(params: dict) -> str: + """ + This function determines the classification task type from the parameters. + + Args: + params (dict): The parameter dictionary containing training and data information. + + Returns: + str: The classification task type (binary or multiclass). + """ + task = "binary" if params["model"]["num_classes"] == 2 else "multiclass" + return task diff --git a/GANDLF/utils/handle_collisions.py b/GANDLF/utils/handle_collisions.py index 9eb7e7848..77a253d97 100644 --- a/GANDLF/utils/handle_collisions.py +++ b/GANDLF/utils/handle_collisions.py @@ -1,10 +1,13 @@ import os +from typing import Tuple import pandas as pd -def handle_collisions(df, headers, output_path): +def handle_collisions(df: pd.DataFrame, headers: dict, output_path: str) -> Tuple[bool, pd.DataFrame]: """ + This function checks for collisions in the subject IDs and updates the subject IDs in the dataframe to avoid collisions. + This function takes a dataframe as input and checks if there are any pairs of subject IDs that are similar to each other. If it finds any such pairs, it renames the subject IDs by adding a suffix of '_v1', '_v2', or '_v3' to differentiate them. The function then creates @@ -13,9 +16,12 @@ def handle_collisions(df, headers, output_path): of any subject ID collisions that were detected during the process. Args: - df (pandas.DataFrame): The input dataframe. - headers (dict): The parsed headers. - output_path (str): The output directory. + df (pd.DataFrame): The dataframe containing the subject IDs to be checked for collisions. + headers (dict): The headers dictionary containing the subjectIDHeader key. + output_path (str): The path to the output directory where the updated dataframe and the collision.csv file will be written. + + Returns: + Tuple[bool, pd.DataFrame]: A tuple containing a boolean indicating whether any collisions were found, and the updated dataframe. """ # Find the subjectID header diff --git a/GANDLF/utils/imaging.py b/GANDLF/utils/imaging.py index 37e88a98c..05c2cf4ce 100644 --- a/GANDLF/utils/imaging.py +++ b/GANDLF/utils/imaging.py @@ -1,67 +1,79 @@ -from typing import Union +from typing import List, Optional, Tuple, Union import os, pathlib, math, copy +from enum import Enum import numpy as np import SimpleITK as sitk import torchio +import cv2 from .generic import get_filename_extension_sanitized def resample_image( - img, spacing, size=None, interpolator=sitk.sitkLinear, outsideValue=0 -): + input_image: sitk.Image, + spacing: Union[np.ndarray, List[float], Tuple[float]], + size: Optional[Union[np.ndarray, List[float], Tuple[float]]] = None, + interpolator: Optional[Enum] = sitk.sitkLinear, + outsideValue: Optional[int] = 0, +) -> sitk.Image: """ - Resample image to certain spacing and size. + This function resamples the input image based on the spacing and size. + Args: - img (SimpleITK.Image): The input image to resample. - spacing (list): List of length 3 indicating the voxel spacing as [x, y, z]. - size (list, optional): List of length 3 indicating the number of voxels per dim [x, y, z], which will use compute the appropriate size based on the spacing. Defaults to []. - interpolator (SimpleITK.InterpolatorEnum, optional): The interpolation type to use. Defaults to SimpleITK.sitkLinear. - origin (list, optional): The location in physical space representing the [0,0,0] voxel in the input image. Defaults to [0,0,0]. - outsideValue (int, optional): value used to pad are outside image. Defaults to 0. - Raises: - Exception: Spacing/resolution mismatch. - Exception: Size mismatch. + input_image (sitk.Image): The input image to be resampled. + spacing (Union[np.ndarray, List[float], Tuple[float]]): The desired spacing for the resampled image. + size (Optional[Union[np.ndarray, List[float], Tuple[float]]], optional): The desired size for the resampled image. Defaults to None. + interpolator (Optional[Enum], optional): The desired interpolator. Defaults to sitk.sitkLinear. + outsideValue (Optional[int], optional): The value to be used for the outside of the image. Defaults to 0. + Returns: - SimpleITK.Image: The resampled input image. + sitk.Image: The resampled image. """ - if len(spacing) != img.GetDimension(): - raise Exception("len(spacing) != " + str(img.GetDimension())) + assert ( + len(spacing) == input_image.GetDimension() + ), "The spacing dimension is inconsistent with the input dataset, please check parameters." # Set Size if size is None: - inSpacing = img.GetSpacing() - inSize = img.GetSize() + inSpacing = input_image.GetSpacing() + inSize = input_image.GetSize() size = [ int(math.ceil(inSize[i] * (inSpacing[i] / spacing[i]))) - for i in range(img.GetDimension()) + for i in range(input_image.GetDimension()) ] - else: - if len(size) != img.GetDimension(): - raise Exception("len(size) != " + str(img.GetDimension())) + + assert ( + len(size) == input_image.GetDimension() + ), "The size dimension is inconsistent with the input dataset, please check parameters." # Resample input image return sitk.Resample( - img, + input_image, size, sitk.Transform(), interpolator, - img.GetOrigin(), + input_image.GetOrigin(), spacing, - img.GetDirection(), + input_image.GetDirection(), outsideValue, ) -def resize_image(input_image, output_size, interpolator=sitk.sitkLinear): +def resize_image( + input_image: sitk.Image, + output_size: Union[np.ndarray, list, tuple], + interpolator: Optional[Enum] = sitk.sitkLinear, +) -> sitk.Image: """ - This function resizes the input image based on the output size and interpolator. + This function resizes the input image based on the output size. + Args: - input_image (SimpleITK.Image): The input image to be resized. - output_size (Union[numpy.ndarray, list, tuple]): The output size to resample input_image to. - interpolator (SimpleITK.sitkInterpolator): The desired interpolator. + input_image (sitk.Image): The input image to be resized. + output_size (Union[np.ndarray, list, tuple]): The desired output size for the resized image. + interpolator (Optional[Enum], optional): The desired interpolator. Defaults to sitk.sitkLinear. + Returns: - SimpleITK.Image: The output image after resizing. + sitk.Image: The resized image. """ output_size_parsed = None inputSize = input_image.GetSize() @@ -87,15 +99,21 @@ def resize_image(input_image, output_size, interpolator=sitk.sitkLinear): ) -def softer_sanity_check(base_property, new_property, threshold=0.00001): +def softer_sanity_check( + base_property: Union[np.ndarray, List[float], Tuple[float]], + new_property: Union[np.ndarray, List[float], Tuple[float]], + threshold: Optional[float] = 0.00001, +) -> bool: """ - This function checks if the new property is within the threshold of the base property. + This function performs a softer sanity check on the input properties. + Args: - base_property (float): The base property to check. - new_property (float): The new property to check - threshold (float, optional): The threshold to check if the new property is within the base property. Defaults to 0.00001. + base_property (Union[np.ndarray, List[float], Tuple[float]]): The base property. + new_property (Union[np.ndarray, List[float], Tuple[float]]): The new property. + threshold (Optional[float], optional): The threshold for comparison. Defaults to 0.00001. + Returns: - bool: Whether the new property is within the threshold of the base property. + bool: True if the properties are consistent within the threshold. """ arr_1 = np.array(base_property) arr_2 = np.array(new_property) @@ -108,21 +126,16 @@ def softer_sanity_check(base_property, new_property, threshold=0.00001): return result -def perform_sanity_check_on_subject(subject, parameters): +def perform_sanity_check_on_subject(subject: torchio.Subject, parameters: dict) -> bool: """ - This function performs sanity check on the subject to ensure presence of consistent header information WITHOUT loading images into memory. + This function performs a sanity check on the image modalities in input subject to ensure that they are consistent. Args: subject (torchio.Subject): The input subject. parameters (dict): The parameters passed by the user yaml. Returns: - bool: True if everything is okay. - - Raises: - ValueError: Dimension mismatch in the images. - ValueError: Origin mismatch in the images. - ValueError: Orientation mismatch in the images. + bool: True if the sanity check passes. """ # read the first image and save that for comparison file_reader_base = None @@ -131,12 +144,14 @@ def perform_sanity_check_on_subject(subject, parameters): if parameters["headers"]["labelHeader"] is not None: list_for_comparison.append("label") - def _get_itkimage_or_filereader(subject_str_key): + def _get_itkimage_or_filereader( + subject_str_key: Union[str, sitk.Image] + ) -> Union[sitk.ImageFileReader, sitk.Image]: """ Helper function to get the itk image or file reader from the subject. Args: - subject_str_key (Union[str, sitk.Image]): The subject string key. + subject_str_key (Union[str, sitk.Image]): The subject string key or itk image. Returns: Union[sitk.ImageFileReader, sitk.Image]: The itk image or file reader. @@ -158,48 +173,44 @@ def _get_itkimage_or_filereader(subject_str_key): file_reader_current = _get_itkimage_or_filereader(subject[str(key)]) # this check needs to be absolute - if ( + assert ( file_reader_base.GetDimension() - != file_reader_current.GetDimension() - ): - raise ValueError( - "Dimensions for Subject '" - + subject["subject_id"] - + "' are not consistent." - ) + == file_reader_current.GetDimension() + ), ( + "Dimensions for Subject '" + + subject["subject_id"] + + "' are not consistent." + ) # other checks can be softer - if not softer_sanity_check( + assert softer_sanity_check( file_reader_base.GetOrigin(), file_reader_current.GetOrigin() - ): - raise ValueError( - "Origin for Subject '" - + subject["subject_id"] - + "' are not consistent." - ) - - if not softer_sanity_check( + ), ( + "Origin for Subject '" + + subject["subject_id"] + + "' are not consistent." + ) + + assert softer_sanity_check( file_reader_base.GetDirection(), file_reader_current.GetDirection() - ): - raise ValueError( - "Orientation for Subject '" - + subject["subject_id"] - + "' are not consistent." - ) - - if not softer_sanity_check( + ), ( + "Orientation for Subject '" + + subject["subject_id"] + + "' are not consistent." + ) + + assert softer_sanity_check( file_reader_base.GetSpacing(), file_reader_current.GetSpacing() - ): - raise ValueError( - "Spacing for Subject '" - + subject["subject_id"] - + "' are not consistent." - ) + ), ( + "Spacing for Subject '" + + subject["subject_id"] + + "' are not consistent." + ) return True -def write_training_patches(subject, params): +def write_training_patches(subject: torchio.Subject, params: dict) -> None: """ This function writes the training patches to disk. @@ -241,12 +252,14 @@ def write_training_patches(subject, params): ) -def get_correct_padding_size(patch_size: Union[list, tuple], model_dimension: int): +def get_correct_padding_size( + patch_size: Union[List[int], Tuple[int]], model_dimension: int +): """ This function returns the correct padding size based on the patch size and overlap. Args: - patch_size (Union[list, tuple]): The patch size. + patch_size (Union[List[int], Tuple[int]]): The patch size. model_dimension (int): The model dimension. Returns: @@ -258,3 +271,21 @@ def get_correct_padding_size(patch_size: Union[list, tuple], model_dimension: in psize_pad[-1] = 0 if psize_pad[-1] == 1 else psize_pad[-1] return psize_pad + + +def applyCustomColorMap(im_gray: np.ndarray) -> np.ndarray: + """ + Internal function to apply a custom color map to the input image. + + Args: + im_gray (np.ndarray): The input image. + + Returns: + np.ndarray: The image with the custom color map applied. + """ + img_bgr = cv2.cvtColor(im_gray.astype(np.uint8), cv2.COLOR_BGR2RGB) + lut = np.zeros((256, 1, 3), dtype=np.uint8) + lut[:, 0, 0] = np.zeros((256)).tolist() + lut[:, 0, 1] = np.zeros((256)).tolist() + lut[:, 0, 2] = np.arange(0, 256, 1).tolist() + return cv2.LUT(img_bgr, lut) diff --git a/GANDLF/utils/modelbase.py b/GANDLF/utils/modelbase.py index 3abe64a06..da216879f 100644 --- a/GANDLF/utils/modelbase.py +++ b/GANDLF/utils/modelbase.py @@ -1,16 +1,17 @@ +from typing import Union import torch import torch.nn.functional as F -def get_modelbase_final_layer(final_convolution_layer): +def get_modelbase_final_layer(final_convolution_layer: str) -> Union[object, None]: """ - This function gets the final layer of the model. + This function returns the final convolution layer based on the input string. Args: - final_convolution_layer (str): The final layer of the model as a string. + final_convolution_layer (str): The string representing the final convolution layer. Returns: - Functional: sigmoid, softmax, or None + Union[object, None]: The final convolution layer. """ none_list = [ "none", diff --git a/GANDLF/utils/modelio.py b/GANDLF/utils/modelio.py index 23556b206..20644e2ae 100644 --- a/GANDLF/utils/modelio.py +++ b/GANDLF/utils/modelio.py @@ -1,7 +1,7 @@ import hashlib import os import subprocess -from typing import Any, Dict +from typing import Any, Dict, Optional, Tuple import torch @@ -30,7 +30,9 @@ initial_model_path_end = "_initial.pth.tar" -def optimize_and_save_model(model, params, path, onnx_export=True): +def optimize_and_save_model( + model: torch.nn.Module, params: dict, path: str, onnx_export: Optional[bool] = True +) -> None: """ Perform post-training optimization and save it to a file. @@ -38,7 +40,7 @@ def optimize_and_save_model(model, params, path, onnx_export=True): model (torch.nn.Module): Trained torch model. params (dict): The parameter dictionary. path (str): The path to save the model dictionary to. - onnx_export (bool): Whether to export to ONNX and OpenVINO. + onnx_export (Optional[bool]): Whether to export to ONNX and OpenVINO. Defaults to True. """ # Check if ONNX export is enabled in the parameter dictionary onnx_export = params["model"].get("onnx_export", onnx_export) @@ -135,8 +137,8 @@ def save_model( model: torch.nn.Module, params: Dict[str, Any], path: str, - onnx_export: bool = True, -): + onnx_export: Optional[bool] = True, +) -> None: """ Save the model dictionary to a file. @@ -145,7 +147,7 @@ def save_model( model (torch.nn.Module): Trained torch model. params (dict): The parameter dictionary. path (str): The path to save the model dictionary to. - onnx_export (bool): Whether to export to ONNX and OpenVINO. + onnx_export (Optional[bool]): Whether to export to ONNX and OpenVINO. Defaults to True. """ model_dict["timestamp"] = get_unique_timestamp() model_dict["timestamp_hash"] = hashlib.sha256( @@ -171,7 +173,7 @@ def save_model( def load_model( - path: str, device: torch.device, full_sanity_check: bool = True + path: str, device: torch.device, full_sanity_check: Optional[bool] = True ) -> Dict[str, Any]: """ Load a model dictionary from a file. @@ -179,7 +181,7 @@ def load_model( Args: path (str): The path to save the model dictionary to. device (torch.device): The device to run the model on. - full_sanity_check (bool): Whether to run full sanity checking on the model. + full_sanity_check (Optional[bool]): Whether to perform a full sanity check. Defaults to True. Returns: dict: Model dictionary containing model parameters and metadata. @@ -191,37 +193,35 @@ def load_model( incomplete_keys = [ key for key in model_dict_full.keys() if key not in model_dict.keys() ] - if len(incomplete_keys) > 0: - raise RuntimeWarning( - "Model dictionary is incomplete; the following keys are missing:", - incomplete_keys, - ) + assert ( + len(incomplete_keys) == 0 + ), "Model dictionary is incomplete; the following keys are missing: " + str( + incomplete_keys + ) # check if required keys are absent, and if so raise an error incomplete_required_keys = [ key for key in model_dict_required.keys() if key not in model_dict.keys() ] - if len(incomplete_required_keys) > 0: - raise KeyError( - "Model dictionary is incomplete; the following keys are missing:", - incomplete_required_keys, - ) + assert ( + len(incomplete_required_keys) == 0 + ), "Model dictionary is incomplete; the following keys are missing: " + str( + incomplete_required_keys + ) return model_dict -def load_ov_model(path: str, device: str = "CPU"): +def load_ov_model(path: str, device: Optional[str] = "CPU") -> Tuple[Any, Any, Any]: """ Load an OpenVINO IR model from an .xml file. Args: path (str): The path to the OpenVINO .xml file. - device (str): The device to run inference, can be "CPU", "GPU" or "MULTI:CPU,GPU". Default to be "CPU". + device (Optional[str]): The device to run the model on, can be "CPU", "GPU" or "MULTI:CPU,GPU". Defaults to "CPU". Returns: - exec_net (OpenVINO executable net): executable OpenVINO model. - input_blob (str): Input name. - output_blob (str): Output name. + Tuple[Any, Any, Any]: The compiled OpenVINO model, input layer name, and output layer name. """ try: diff --git a/GANDLF/utils/parameter_processing.py b/GANDLF/utils/parameter_processing.py index e41ac7f7a..42ce1d140 100644 --- a/GANDLF/utils/parameter_processing.py +++ b/GANDLF/utils/parameter_processing.py @@ -1,8 +1,9 @@ +from torch.utils.data import DataLoader from GANDLF.utils.modelbase import get_modelbase_final_layer from GANDLF.metrics import surface_distance_ids -def populate_header_in_parameters(parameters, headers): +def populate_header_in_parameters(parameters: dict, headers: dict) -> dict: """ This function populates the parameters with information from the header in a common manner @@ -49,13 +50,13 @@ def populate_header_in_parameters(parameters, headers): return parameters -def find_problem_type(parameters, model_final_layer): +def find_problem_type(parameters: dict, model_final_layer: str) -> str: """ This function determines the type of problem at hand - regression, classification or segmentation Args: parameters (dict): The parameters passed by the user yaml. - model_final_layer (model_final_layer): The final layer of the model. If None, the model is for regression. + model_final_layer (str): The final layer of the model. If None, the model is for regression. Returns: str: The problem type (regression/classification/segmentation). @@ -84,7 +85,7 @@ def find_problem_type(parameters, model_final_layer): return "segmentation" -def find_problem_type_from_parameters(parameters): +def find_problem_type_from_parameters(parameters: dict) -> str: """ This function determines the type of problem at hand - regression, classification or segmentation @@ -124,12 +125,12 @@ def find_problem_type_from_parameters(parameters): return "regression" -def populate_channel_keys_in_params(data_loader, parameters): +def populate_channel_keys_in_params(data_loader: DataLoader, parameters: dict): """ Function to read channel key information from specified data loader Args: - data_loader (torch.DataLoader): The data loader to query key information from. + data_loader (DataLoader): The data loader to query key information from. parameters (dict): The parameters passed by the user yaml. Returns: diff --git a/GANDLF/utils/tensor.py b/GANDLF/utils/tensor.py index bbaa5ae03..bc4d6417c 100644 --- a/GANDLF/utils/tensor.py +++ b/GANDLF/utils/tensor.py @@ -1,12 +1,14 @@ import os, sys -from typing import Union +from typing import List, Optional, Tuple, Union from pandas.util import hash_pandas_object import numpy as np import SimpleITK as sitk +import pandas as pd import torch import torch.nn as nn from torch.utils.data import DataLoader import torchio +import torchmetrics from tqdm import tqdm from torchinfo import summary from GANDLF.utils.generic import get_array_from_image_or_tensor @@ -15,13 +17,15 @@ special_cases_to_check = ["||"] -def one_hot(segmask_tensor, class_list): +def one_hot( + segmask_tensor: torch.Tensor, class_list: Union[List[int], List[str]] +) -> torch.Tensor: """ This function creates a one-hot-encoded mask from the segmentation mask Tensor and specified class list Args: segmask_tensor (torch.Tensor): The segmentation mask Tensor. - class_list (list): The list of classes based on which one-hot encoding needs to happen. + class_list (Union[List[int], List[str]]): The list of classes based on which one-hot encoding needs to happen. Returns: torch.Tensor: The one-hot encoded torch.Tensor @@ -80,13 +84,13 @@ def one_hot(segmask_tensor, class_list): return batch_stack -def reverse_one_hot(predmask_tensor, class_list): +def reverse_one_hot(predmask_tensor: torch.Tensor, class_list: Union[List[int], List[str]]) -> np.array: """ This function creates a full segmentation mask Tensor from a one-hot-encoded mask and specified class list Args: predmask_tensor (torch.Tensor): The predicted segmentation mask Tensor. - class_list (list): The list of classes based on which one-hot encoding needs to happen. + class_list (Union[List[int], List[str]]): The list of classes based on which one-hot encoding needs to happen. Returns: numpy.array: The final mask as numpy array. @@ -125,7 +129,9 @@ def reverse_one_hot(predmask_tensor, class_list): return final_mask -def send_model_to_device(model, amp, device, optimizer): +def send_model_to_device( + model: torch.nn.Module, amp: bool, device: str, optimizer: torch.optim +) -> Union[torch.nn.Module, bool, torch.device, int]: """ This function reads the environment variable(s) and send model to correct device @@ -141,10 +147,10 @@ def send_model_to_device(model, amp, device, optimizer): torch.device: Device type. """ if device == "cuda": - if os.environ.get("CUDA_VISIBLE_DEVICES") is None: - sys.exit( - "Please set the environment variable 'CUDA_VISIBLE_DEVICES' correctly before trying to run GANDLF on GPU" - ) + assert torch.cuda.is_available(), "CUDA is either not available or not enabled" + assert ( + os.environ.get("CUDA_VISIBLE_DEVICES") is not None + ), "CUDA_VISIBLE_DEVICES is not set" dev = os.environ.get("CUDA_VISIBLE_DEVICES") # multi-gpu support @@ -212,13 +218,13 @@ def send_model_to_device(model, amp, device, optimizer): return model, amp, device, dev -def get_model_dict(model, device_id): +def get_model_dict(model: torch.nn.Module, device_id: Union[str, List[str]]) -> dict: """ This function returns the model dictionary Args: model (torch.nn.Module): The model for which the dictionary is to be returned. - device_id (Union[str, list]): The device id as string or list. + device_id (Union[str, List[str]]): The device id as string or list of devices. Returns: dict: The model dictionary. @@ -235,7 +241,9 @@ def get_model_dict(model, device_id): return model_dict -def get_class_imbalance_weights_classification(training_df, params): +def get_class_imbalance_weights_classification( + training_df: pd.DataFrame, params: dict +) -> Tuple[dict, dict, dict]: """ This function calculates the penalty used for loss functions in multi-class problems. It looks at the column "valuesToPredict" and identifies unique classes, fetches the class distribution @@ -247,7 +255,7 @@ def get_class_imbalance_weights_classification(training_df, params): parameters (dict) : The parameters passed by the user yaml. Returns: - dict: The penalty weights for different classes under consideration for classification. + Tuple[dict, dict, dict]: The penalty weights, sampling weights, and class weights for different classes under consideration. """ predictions_array = ( training_df[training_df.columns[params["headers"]["predictionHeaders"]]] @@ -287,7 +295,9 @@ def get_class_imbalance_weights_classification(training_df, params): return penalty_dict, None, weight_dict -def get_class_imbalance_weights_segmentation(training_data_loader, parameters): +def get_class_imbalance_weights_segmentation( + training_data_loader: DataLoader, parameters: dict +) -> Tuple[dict, dict, dict]: """ This function calculates the penalty that is used for validation loss in multi-class problems @@ -296,7 +306,7 @@ def get_class_imbalance_weights_segmentation(training_data_loader, parameters): parameters (dict): The parameters passed by the user yaml. Returns: - dict: The penalty weights for different classes under consideration. + Tuple[dict, dict, dict]: The penalty weights, sampling weights, and class weights for different classes under consideration. """ abs_dict = {} # absolute counts for each class weights_dict = {} # average for "weighted averaging" @@ -353,16 +363,18 @@ def get_class_imbalance_weights_segmentation(training_data_loader, parameters): return penalty_dict, penalty_dict, weights_dict -def get_class_imbalance_weights(training_df, params): +def get_class_imbalance_weights( + training_df: pd.DataFrame, params: dict +) -> Tuple[dict, dict, dict]: """ - This is a wrapper function that calculates the penalty used for loss functions in classification/segmentation problems. + This function calculates the penalty that is used for validation loss in multi-class problems Args: - training_Df (pd.DataFrame): The training data frame. - parameters (dict) : The parameters passed by the user yaml. + training_df (pd.DataFrame): The training data frame. + params (dict): The parameters passed by the user yaml. Returns: - float, float: The penalty and class weights for different classes under consideration for classification. + Tuple[dict, dict, dict]: The penalty weights, sampling weights, and class weights for different classes under consideration. """ penalty_weights, sampling_weights, class_weights = None, None, None if params["weighted_loss"] or params["patch_sampler"]["biased_sampling"]: @@ -426,7 +438,7 @@ def get_class_imbalance_weights(training_df, params): return penalty_weights, sampling_weights, class_weights -def get_linear_interpolation_mode(dimensionality): +def get_linear_interpolation_mode(dimensionality: int) -> str: """ Get linear interpolation mode. @@ -447,18 +459,21 @@ def get_linear_interpolation_mode(dimensionality): def print_model_summary( - model, input_batch_size, input_num_channels, input_patch_size, device=None -): + model: torch.nn.Module, + input_batch_size: int, + input_num_channels: int, + input_patch_size: tuple, + device: Optional[torch.device] = None, +) -> None: """ - _summary_ - Estimates the size of PyTorch models in memory - for a given input size + This function prints the model summary. + Args: - model (torch.nn.Module): The model to be summarized. - input_batch_size (int): The batch size of the input. - input_num_channels (int): The number of channels of the input. - input_patch_size (tuple): The patch size of the input. - device (torch.device, optional): The device on which the model is run. Defaults to None. + model (torch.nn.Module): The model for which the summary is to be printed. + input_batch_size (int): The input batch size. + input_num_channels (int): The input number of channels. + input_patch_size (tuple): The input patch size. + device (Optional[torch.device], optional): The device. Defaults to None. """ input_size = (input_batch_size, input_num_channels) + tuple(input_patch_size) if input_size[-1] == 1: @@ -483,7 +498,9 @@ def print_model_summary( print("Failed to generate model summary with error: ", e) -def get_ground_truths_and_predictions_tensor(params, loader_type): +def get_ground_truths_and_predictions_tensor( + params: dict, loader_type: str +) -> Tuple[torch.Tensor, torch.Tensor]: """ This function is used to get the ground truths and predictions for a given loader type. @@ -492,7 +509,7 @@ def get_ground_truths_and_predictions_tensor(params, loader_type): loader_type (str): The loader type for which the ground truths and predictions are to be returned. Returns: - torch.Tensor, torch.Tensor: The ground truths and base predictions for the given loader type. + Tuple[torch.Tensor, torch.Tensor]: The ground truths and predictions for the given loader type. """ ground_truth_array = torch.from_numpy( params[loader_type][ @@ -506,7 +523,9 @@ def get_ground_truths_and_predictions_tensor(params, loader_type): return ground_truth_array, predictions_array -def get_output_from_calculator(predictions, ground_truth, calculator): +def get_output_from_calculator( + prediction: torch.Tensor, target: torch.tensor, calculator: torchmetrics.Metric +) -> float: """ Helper function to get the output from a calculator. @@ -518,7 +537,7 @@ def get_output_from_calculator(predictions, ground_truth, calculator): Returns: float: The output from the calculator. """ - temp_output = calculator(predictions, ground_truth) + temp_output = calculator(prediction, target) if temp_output.dim() > 0: temp_output = temp_output.cpu().tolist() else: diff --git a/GANDLF/utils/write_parse.py b/GANDLF/utils/write_parse.py index 9673451e7..03ee27552 100644 --- a/GANDLF/utils/write_parse.py +++ b/GANDLF/utils/write_parse.py @@ -1,6 +1,7 @@ import os import pathlib import sys +from typing import Optional, Tuple, Union import pandas as pd @@ -8,16 +9,21 @@ def writeTrainingCSV( - inputDir, channelsID, labelID, outputFile, relativizePathsToOutput=False + inputDir: str, + channelsID: str, + labelID: str, + outputFile: str, + relativizePathsToOutput: Optional[bool] = False, ) -> None: """ - This function writes the CSV file based on the input directory, channelsID + labelsID strings + This function writes a CSV file containing the paths to the training data. Args: - inputDir (str): The input directory. - channelsID (str): The channel header(s) identifiers. - labelID (str): The label header identifier. - outputFile (str): The output files to write + inputDir (str): The input directory containing all the training data. + channelsID (str): The channel IDs. + labelID (str): The label ID. + outputFile (str): The output CSV file. + relativizePathsToOutput (Optional[bool], optional): Whether to relativize the paths to the output file. Defaults to False. """ channelsID_list = channelsID.split(",") # split into list @@ -64,17 +70,18 @@ def writeTrainingCSV( file.close() -def parseTrainingCSV(inputTrainingCSVFile, train=True) -> (pd.DataFrame, dict): +def parseTrainingCSV( + inputTrainingCSVFile: str, train: Optional[bool] = True +) -> Tuple[pd.DataFrame, dict]: """ This function parses the input training CSV and returns a dictionary of headers and the full (randomized) data frame Args: inputTrainingCSVFile (str): The input data CSV file which contains all training data. - train (bool, optional): Whether performing training. Defaults to True. + train (Optional[bool], optional): Whether to train the model. Defaults to True. Returns: - pandas.DataFrame: The full dataset for computation. - dict: The dictionary containing all relevant CSV headers. + Tuple[pd.DataFrame, dict]: The full dataset for computation and the dictionary containing all relevant CSV headers. """ ## read training dataset into data frame data_full = get_dataframe(inputTrainingCSVFile) @@ -125,20 +132,19 @@ def parseTrainingCSV(inputTrainingCSVFile, train=True) -> (pd.DataFrame, dict): return data_full, headers -def parseTestingCSV(inputTrainingCSVFile, output_dir) -> (bool, pd.DataFrame, dict): +def parseTestingCSV( + inputTrainingCSVFile, output_dir +) -> Tuple[bool, pd.DataFrame, dict]: """ - This function parses the input training CSV and returns a dictionary of headers and the full (randomized) data frame + This function parses the input testing CSV and returns a dictionary of headers and the full (randomized) data frame Args: - inputTrainingCSVFile (str): The input data CSV file which contains all training data. - train (bool, optional): Whether performing training. Defaults to True. + inputTrainingCSVFile (str): The input data CSV file which contains all testing data. + output_dir (str): The output directory for the updated_test_mapping.csv and the collision.csv. Returns: - bool: Whether collisions were found or not. - pandas.DataFrame: The full dataset for computation. - dict: The dictionary containing all relevant CSV headers. + Tuple[bool, pd.DataFrame, dict]: A boolean indicating whether any collisions were found, the full dataset for computation, and the dictionary containing all relevant CSV headers. """ - data_full, headers = parseTrainingCSV(inputTrainingCSVFile, train=False) collision_status, data_full = handle_collisions(data_full, headers, output_dir) @@ -159,7 +165,7 @@ def parseTestingCSV(inputTrainingCSVFile, output_dir) -> (bool, pd.DataFrame, di return collision_status, data_full, headers -def get_dataframe(input_file) -> pd.DataFrame: +def get_dataframe(input_file: Union[str, pd.DataFrame]) -> pd.DataFrame: """ This function parses the input and returns a data frame @@ -182,7 +188,7 @@ def get_dataframe(input_file) -> pd.DataFrame: def convert_relative_paths_in_dataframe( - input_dataframe, headers, path_root + input_dataframe: pd.DataFrame, headers: dict, path_root: str ) -> pd.DataFrame: """ This function takes a dataframe containing paths and a root path (usually to a data CSV file).