diff --git a/GANDLF/data/patch_miner/opm/utils.py b/GANDLF/data/patch_miner/opm/utils.py index 9c66118a8..53bf3605b 100644 --- a/GANDLF/data/patch_miner/opm/utils.py +++ b/GANDLF/data/patch_miner/opm/utils.py @@ -10,8 +10,9 @@ from skimage.morphology import remove_small_holes from skimage.color.colorconv import rgb2hsv import cv2 -#from skimage.exposure import rescale_intensity -#from skimage.color import rgb2hed + +# from skimage.exposure import rescale_intensity +# from skimage.color import rgb2hed # import matplotlib.pyplot as plt import yaml @@ -238,7 +239,8 @@ def alpha_rgb_2d_channel_check(img): else: return False -#def pen_marking_check(img, intensity_thresh=225, intensity_thresh_saturation =50, intensity_thresh_b = 128): + +# def pen_marking_check(img, intensity_thresh=225, intensity_thresh_saturation =50, intensity_thresh_b = 128): # """ # This function is used to curate patches from the input image. It is used to remove patches that have pen markings. # Args: @@ -259,7 +261,14 @@ def alpha_rgb_2d_channel_check(img): # #Assume patch is valid # return True -def patch_artifact_check(img, intensity_thresh = 250, intensity_thresh_saturation = 5, intensity_thresh_b = 128, patch_size = (256,256)): + +def patch_artifact_check( + img, + intensity_thresh=250, + intensity_thresh_saturation=5, + intensity_thresh_b=128, + patch_size=(256, 256), +): """ This function is used to curate patches from the input image. It is used to remove patches that are mostly background. Args: @@ -271,23 +280,36 @@ def patch_artifact_check(img, intensity_thresh = 250, intensity_thresh_saturatio Returns: bool: Whether the patch is valid (True) or not (False) """ - #patch_size = config["patch_size"] + # patch_size = config["patch_size"] patch_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) count_white_pixels = np.sum(np.logical_and.reduce(img > intensity_thresh, axis=2)) percent_pixels = count_white_pixels / (patch_size[0] * patch_size[1]) count_black_pixels = np.sum(np.logical_and.reduce(img < intensity_thresh_b, axis=2)) percent_pixel_b = count_black_pixels / (patch_size[0] * patch_size[1]) - percent_pixel_2 = np.sum(patch_hsv[...,1] < intensity_thresh_saturation) / (patch_size[0] * patch_size[1]) - percent_pixel_3 = np.sum(patch_hsv[...,2] > intensity_thresh) / (patch_size[0] * patch_size[1]) + percent_pixel_2 = np.sum(patch_hsv[..., 1] < intensity_thresh_saturation) / ( + patch_size[0] * patch_size[1] + ) + percent_pixel_3 = np.sum(patch_hsv[..., 2] > intensity_thresh) / ( + patch_size[0] * patch_size[1] + ) - if percent_pixel_2 > 0.99 or np.mean(patch_hsv[...,1]) < 5 or percent_pixel_3 > 0.99: + if ( + percent_pixel_2 > 0.99 + or np.mean(patch_hsv[..., 1]) < 5 + or percent_pixel_3 > 0.99 + ): if percent_pixel_2 < 0.1: return False - elif (percent_pixel_2 > 0.99 and percent_pixel_3 > 0.99) or percent_pixel_b > 0.99 or percent_pixels > 0.9: + elif ( + (percent_pixel_2 > 0.99 and percent_pixel_3 > 0.99) + or percent_pixel_b > 0.99 + or percent_pixels > 0.9 + ): return False # assume that the patch is valid return True + def parse_config(config_file): """ Parse config file and return a dictionary of config values. @@ -304,7 +326,7 @@ def parse_config(config_file): config["value_map"] = config.get("value_map", None) config["read_type"] = config.get("read_type", "random") config["overlap_factor"] = config.get("overlap_factor", 0.0) - config["patch_size"] = config.get("patch_size", [256,256]) + config["patch_size"] = config.get("patch_size", [256, 256]) return config diff --git a/GANDLF/metrics/classification.py b/GANDLF/metrics/classification.py index 842dc88f4..ec4e9d160 100644 --- a/GANDLF/metrics/classification.py +++ b/GANDLF/metrics/classification.py @@ -64,9 +64,7 @@ def overall_stats(predictions, ground_truth, params): "aucroc": tm.AUROC( task=task, num_classes=params["model"]["num_classes"], - average=average_type_key - if average_type_key != "micro" - else "macro", + average=average_type_key if average_type_key != "micro" else "macro", ), } for metric_name, calculator in calculators.items(): diff --git a/GANDLF/metrics/generic.py b/GANDLF/metrics/generic.py index 34b197df7..1c90a3a20 100644 --- a/GANDLF/metrics/generic.py +++ b/GANDLF/metrics/generic.py @@ -15,9 +15,7 @@ ) -def generic_function_output_with_check( - predicted_classes, label, metric_function -): +def generic_function_output_with_check(predicted_classes, label, metric_function): if torch.min(predicted_classes) < 0: print( "WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations." @@ -32,16 +30,12 @@ def generic_function_output_with_check( max_clamp_val = metric_function.num_classes - 1 except AttributeError: max_clamp_val = 1 - predicted_new = torch.clamp( - predicted_classes.cpu().int(), max=max_clamp_val - ) + predicted_new = torch.clamp(predicted_classes.cpu().int(), max=max_clamp_val) predicted_new = predicted_new.reshape(label.shape) return metric_function(predicted_new, label.cpu().int()) -def generic_torchmetrics_score( - output, label, metric_class, metric_key, params -): +def generic_torchmetrics_score(output, label, metric_class, metric_key, params): task = determine_classification_task_type(params) num_classes = params["model"]["num_classes"] predicted_classes = output @@ -67,9 +61,7 @@ def recall_score(output, label, params): def precision_score(output, label, params): - return generic_torchmetrics_score( - output, label, Precision, "precision", params - ) + return generic_torchmetrics_score(output, label, Precision, "precision", params) def f1_score(output, label, params): @@ -77,15 +69,11 @@ def f1_score(output, label, params): def accuracy(output, label, params): - return generic_torchmetrics_score( - output, label, Accuracy, "accuracy", params - ) + return generic_torchmetrics_score(output, label, Accuracy, "accuracy", params) def specificity_score(output, label, params): - return generic_torchmetrics_score( - output, label, Specificity, "specificity", params - ) + return generic_torchmetrics_score(output, label, Specificity, "specificity", params) def iou_score(output, label, params): diff --git a/GANDLF/utils/generic.py b/GANDLF/utils/generic.py index 050cce4ff..52f6a33d9 100644 --- a/GANDLF/utils/generic.py +++ b/GANDLF/utils/generic.py @@ -49,7 +49,9 @@ def checkPatchDivisibility(patch_size, number=16): return True -def determine_classification_task_type(params: Dict[str, Union[Dict[str, Any], Any]]) -> str: +def determine_classification_task_type( + params: Dict[str, Union[Dict[str, Any], Any]] +) -> str: """Determine the task (binary or multiclass) from the model config. Args: params (dict): The parameter dictionary containing training and data information. @@ -159,10 +161,7 @@ def checkPatchDimensions(patch_size, numlay): patch_size_to_check = patch_size_to_check[:-1] if all( - [ - x >= 2 ** (numlay + 1) and x % 2**numlay == 0 - for x in patch_size_to_check - ] + [x >= 2 ** (numlay + 1) and x % 2**numlay == 0 for x in patch_size_to_check] ): return numlay else: @@ -198,9 +197,7 @@ def get_array_from_image_or_tensor(input_tensor_or_image): elif isinstance(input_tensor_or_image, np.ndarray): return input_tensor_or_image else: - raise ValueError( - "Input must be a torch.Tensor or sitk.Image or np.ndarray" - ) + raise ValueError("Input must be a torch.Tensor or sitk.Image or np.ndarray") def set_determinism(seed=42): @@ -270,9 +267,7 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict: output_metrics_dict = deepcopy(cohort_level_metrics) for metric in metrics_dict_from_parameters: if isinstance(sample_level_metrics[metric], np.ndarray): - to_print = ( - sample_level_metrics[metric] / length_of_dataloader - ).tolist() + to_print = (sample_level_metrics[metric] / length_of_dataloader).tolist() else: to_print = sample_level_metrics[metric] / length_of_dataloader output_metrics_dict[metric] = to_print @@ -315,7 +310,5 @@ def define_multidim_average_type_key(params, metric_name) -> str: Returns: str: The average type key. """ - average_type_key = params["metrics"][metric_name].get( - "multidim_average", "global" - ) + average_type_key = params["metrics"][metric_name].get("multidim_average", "global") return average_type_key