Skip to content

Commit

Permalink
Introduced percentile normalization for synthesis challenge metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixSteinbauer authored Jul 29, 2023
1 parent d1fc9e2 commit 941c85b
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions GANDLF/cli/generate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,32 @@ def __fix_2d_tensor(input_tensor):
else:
return input_tensor

def __percentile_clip(input_tensor, reference_tensor=None, p_min=0.5, p_max=99.5, strictlyPositive=True):
"""Normalizes a tensor based on percentiles. Clips values below and above the percentile.
Percentiles for normalization can come from another tensor.
Args:
input_tensor (torch.Tensor): Tensor to be normalized based on the data from the reference_tensor.
If reference_tensor is None, the percentiles from this tensor will be used.
reference_tensor (torch.Tensor, optional): The tensor used for obtaining the percentiles.
p_min (float, optional): Lower end percentile. Defaults to 0.5.
p_max (float, optional): Upper end percentile. Defaults to 99.5.
strictlyPositive (bool, optional): Ensures that really all values are above 0 before normalization. Defaults to True.
Returns:
torch.Tensor: The input_tensor normalized based on the percentiles of the reference tensor.
"""
if(reference_tensor == None):
reference_tensor = input_tensor
v_min, v_max = np.percentile(reference_tensor, [p_min,p_max]) #get p_min percentile and p_max percentile

if( v_min < 0 and strictlyPositive): #set lower bound to be 0 if it would be below
v_min = 0
output_tensor = np.clip(input_tensor,v_min,v_max) #clip values to percentiles from reference_tensor
output_tensor = (output_tensor - v_min)/(v_max-v_min) #normalizes values to [0;1]

return output_tensor

for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
current_subject_id = row[headers["subjectid"]]
overall_stats_dict[current_subject_id] = {}
Expand All @@ -219,9 +245,9 @@ def __fix_2d_tensor(input_tensor):
# Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range)
normalize = parameters.get("normalize", True)
if normalize:
v_max = gt_image_infill.max()
output_infill /= v_max
gt_image_infill /= v_max
reference_tensor = gt_image * ~mask #use all the tissue that is not masked for normalization
gt_image_infill = __percentile_clip(gt_image_infill, reference_tensor=reference_tensor, p_min=0.5, p_max=99.5, strictlyPositive=True)
prediction_infill = __percentile_clip(prediction_infill, reference_tensor=reference_tensor, p_min=0.5, p_max=99.5, strictlyPositive=True)

overall_stats_dict[current_subject_id][
"ssim"
Expand Down Expand Up @@ -258,6 +284,7 @@ def __fix_2d_tensor(input_tensor):
gt_image_infill, output_infill
).item()

#TODO: use data_range=1.0 as parameter for PSNR when the Pull request is accepted that introduces the data_range parameter!
# PSNR - similar to pytorch PeakSignalNoiseRatio until 4 digits after decimal point
overall_stats_dict[current_subject_id]["psnr"] = peak_signal_noise_ratio(
gt_image_infill, output_infill
Expand Down

0 comments on commit 941c85b

Please sign in to comment.