Skip to content

Commit

Permalink
Merge pull request #746 from scap3yvt/scap3yvt-patch-black
Browse files Browse the repository at this point in the history
Consistent black across the codebase
  • Loading branch information
sarthakpati authored Nov 22, 2023
2 parents 10821bd + ec115db commit 3ef60ff
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 45 deletions.
42 changes: 32 additions & 10 deletions GANDLF/data/patch_miner/opm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from skimage.morphology import remove_small_holes
from skimage.color.colorconv import rgb2hsv
import cv2
#from skimage.exposure import rescale_intensity
#from skimage.color import rgb2hed

# from skimage.exposure import rescale_intensity
# from skimage.color import rgb2hed

# import matplotlib.pyplot as plt
import yaml
Expand Down Expand Up @@ -238,7 +239,8 @@ def alpha_rgb_2d_channel_check(img):
else:
return False

#def pen_marking_check(img, intensity_thresh=225, intensity_thresh_saturation =50, intensity_thresh_b = 128):

# def pen_marking_check(img, intensity_thresh=225, intensity_thresh_saturation =50, intensity_thresh_b = 128):
# """
# This function is used to curate patches from the input image. It is used to remove patches that have pen markings.
# Args:
Expand All @@ -259,7 +261,14 @@ def alpha_rgb_2d_channel_check(img):
# #Assume patch is valid
# return True

def patch_artifact_check(img, intensity_thresh = 250, intensity_thresh_saturation = 5, intensity_thresh_b = 128, patch_size = (256,256)):

def patch_artifact_check(
img,
intensity_thresh=250,
intensity_thresh_saturation=5,
intensity_thresh_b=128,
patch_size=(256, 256),
):
"""
This function is used to curate patches from the input image. It is used to remove patches that are mostly background.
Args:
Expand All @@ -271,23 +280,36 @@ def patch_artifact_check(img, intensity_thresh = 250, intensity_thresh_saturatio
Returns:
bool: Whether the patch is valid (True) or not (False)
"""
#patch_size = config["patch_size"]
# patch_size = config["patch_size"]
patch_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
count_white_pixels = np.sum(np.logical_and.reduce(img > intensity_thresh, axis=2))
percent_pixels = count_white_pixels / (patch_size[0] * patch_size[1])
count_black_pixels = np.sum(np.logical_and.reduce(img < intensity_thresh_b, axis=2))
percent_pixel_b = count_black_pixels / (patch_size[0] * patch_size[1])
percent_pixel_2 = np.sum(patch_hsv[...,1] < intensity_thresh_saturation) / (patch_size[0] * patch_size[1])
percent_pixel_3 = np.sum(patch_hsv[...,2] > intensity_thresh) / (patch_size[0] * patch_size[1])
percent_pixel_2 = np.sum(patch_hsv[..., 1] < intensity_thresh_saturation) / (
patch_size[0] * patch_size[1]
)
percent_pixel_3 = np.sum(patch_hsv[..., 2] > intensity_thresh) / (
patch_size[0] * patch_size[1]
)

if percent_pixel_2 > 0.99 or np.mean(patch_hsv[...,1]) < 5 or percent_pixel_3 > 0.99:
if (
percent_pixel_2 > 0.99
or np.mean(patch_hsv[..., 1]) < 5
or percent_pixel_3 > 0.99
):
if percent_pixel_2 < 0.1:
return False
elif (percent_pixel_2 > 0.99 and percent_pixel_3 > 0.99) or percent_pixel_b > 0.99 or percent_pixels > 0.9:
elif (
(percent_pixel_2 > 0.99 and percent_pixel_3 > 0.99)
or percent_pixel_b > 0.99
or percent_pixels > 0.9
):
return False
# assume that the patch is valid
return True


def parse_config(config_file):
"""
Parse config file and return a dictionary of config values.
Expand All @@ -304,7 +326,7 @@ def parse_config(config_file):
config["value_map"] = config.get("value_map", None)
config["read_type"] = config.get("read_type", "random")
config["overlap_factor"] = config.get("overlap_factor", 0.0)
config["patch_size"] = config.get("patch_size", [256,256])
config["patch_size"] = config.get("patch_size", [256, 256])

return config

Expand Down
4 changes: 1 addition & 3 deletions GANDLF/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def overall_stats(predictions, ground_truth, params):
"aucroc": tm.AUROC(
task=task,
num_classes=params["model"]["num_classes"],
average=average_type_key
if average_type_key != "micro"
else "macro",
average=average_type_key if average_type_key != "micro" else "macro",
),
}
for metric_name, calculator in calculators.items():
Expand Down
24 changes: 6 additions & 18 deletions GANDLF/metrics/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
)


def generic_function_output_with_check(
predicted_classes, label, metric_function
):
def generic_function_output_with_check(predicted_classes, label, metric_function):
if torch.min(predicted_classes) < 0:
print(
"WARNING: Negative values detected in prediction, cannot compute torchmetrics calculations."
Expand All @@ -32,16 +30,12 @@ def generic_function_output_with_check(
max_clamp_val = metric_function.num_classes - 1
except AttributeError:
max_clamp_val = 1
predicted_new = torch.clamp(
predicted_classes.cpu().int(), max=max_clamp_val
)
predicted_new = torch.clamp(predicted_classes.cpu().int(), max=max_clamp_val)
predicted_new = predicted_new.reshape(label.shape)
return metric_function(predicted_new, label.cpu().int())


def generic_torchmetrics_score(
output, label, metric_class, metric_key, params
):
def generic_torchmetrics_score(output, label, metric_class, metric_key, params):
task = determine_classification_task_type(params)
num_classes = params["model"]["num_classes"]
predicted_classes = output
Expand All @@ -67,25 +61,19 @@ def recall_score(output, label, params):


def precision_score(output, label, params):
return generic_torchmetrics_score(
output, label, Precision, "precision", params
)
return generic_torchmetrics_score(output, label, Precision, "precision", params)


def f1_score(output, label, params):
return generic_torchmetrics_score(output, label, F1Score, "f1", params)


def accuracy(output, label, params):
return generic_torchmetrics_score(
output, label, Accuracy, "accuracy", params
)
return generic_torchmetrics_score(output, label, Accuracy, "accuracy", params)


def specificity_score(output, label, params):
return generic_torchmetrics_score(
output, label, Specificity, "specificity", params
)
return generic_torchmetrics_score(output, label, Specificity, "specificity", params)


def iou_score(output, label, params):
Expand Down
21 changes: 7 additions & 14 deletions GANDLF/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def checkPatchDivisibility(patch_size, number=16):
return True


def determine_classification_task_type(params: Dict[str, Union[Dict[str, Any], Any]]) -> str:
def determine_classification_task_type(
params: Dict[str, Union[Dict[str, Any], Any]]
) -> str:
"""Determine the task (binary or multiclass) from the model config.
Args:
params (dict): The parameter dictionary containing training and data information.
Expand Down Expand Up @@ -159,10 +161,7 @@ def checkPatchDimensions(patch_size, numlay):
patch_size_to_check = patch_size_to_check[:-1]

if all(
[
x >= 2 ** (numlay + 1) and x % 2**numlay == 0
for x in patch_size_to_check
]
[x >= 2 ** (numlay + 1) and x % 2**numlay == 0 for x in patch_size_to_check]
):
return numlay
else:
Expand Down Expand Up @@ -198,9 +197,7 @@ def get_array_from_image_or_tensor(input_tensor_or_image):
elif isinstance(input_tensor_or_image, np.ndarray):
return input_tensor_or_image
else:
raise ValueError(
"Input must be a torch.Tensor or sitk.Image or np.ndarray"
)
raise ValueError("Input must be a torch.Tensor or sitk.Image or np.ndarray")


def set_determinism(seed=42):
Expand Down Expand Up @@ -270,9 +267,7 @@ def __update_metric_from_list_to_single_string(input_metrics_dict) -> dict:
output_metrics_dict = deepcopy(cohort_level_metrics)
for metric in metrics_dict_from_parameters:
if isinstance(sample_level_metrics[metric], np.ndarray):
to_print = (
sample_level_metrics[metric] / length_of_dataloader
).tolist()
to_print = (sample_level_metrics[metric] / length_of_dataloader).tolist()
else:
to_print = sample_level_metrics[metric] / length_of_dataloader
output_metrics_dict[metric] = to_print
Expand Down Expand Up @@ -315,7 +310,5 @@ def define_multidim_average_type_key(params, metric_name) -> str:
Returns:
str: The average type key.
"""
average_type_key = params["metrics"][metric_name].get(
"multidim_average", "global"
)
average_type_key = params["metrics"][metric_name].get("multidim_average", "global")
return average_type_key

0 comments on commit 3ef60ff

Please sign in to comment.