From 1d492aa87a9cd284e87764577c15226e20563225 Mon Sep 17 00:00:00 2001 From: FelixSteinbauer <43598998+FelixSteinbauer@users.noreply.github.com> Date: Thu, 10 Aug 2023 16:21:29 +0200 Subject: [PATCH 1/3] Use tuple for data_range in PSNR instead of range size --- GANDLF/metrics/synthesis.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index fd3590a58..3b56e61b0 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -55,7 +55,7 @@ def peak_signal_noise_ratio(target, prediction, data_range=None, epsilon=None) - Args: target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. - data_range (float, optional): If not None, this data range is used as enumerator instead of computing it from the given data. Defaults to None. + data_range (tuple, optional): If not None, this data range (min, max) is used as enumerator instead of computing it from the given data. Defaults to None. epsilon (float, optional): If not None, this epsilon is added to the denominator of the fraction to avoid infinity as output. Defaults to None. """ @@ -67,8 +67,9 @@ def peak_signal_noise_ratio(target, prediction, data_range=None, epsilon=None) - if data_range == None: #compute data_range like torchmetrics if not given min_v = 0 if torch.min(target) > 0 else torch.min(target) #look at this line max_v = torch.max(target) - data_range = max_v - min_v - return 10.0 * torch.log10((data_range ** 2) / (mse + epsilon)) + else: + min_v, max_v = data_range + return 10.0 * torch.log10(((max_v-min_v) ** 2) / (mse + epsilon)) def mean_squared_log_error(target, prediction) -> torch.Tensor: From 5690df784311a619b900418d14b1d1e8941dcf2e Mon Sep 17 00:00:00 2001 From: FelixSteinbauer <43598998+FelixSteinbauer@users.noreply.github.com> Date: Thu, 10 Aug 2023 16:31:00 +0200 Subject: [PATCH 2/3] Use tuple for datat range & fixed the missing range in the psnr_01_eps case --- GANDLF/cli/generate_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index 07423a3fc..aa47e5f69 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -301,14 +301,14 @@ def __percentile_clip(input_tensor, reference_tensor=None, p_min=0.5, p_max=99.5 overall_stats_dict[current_subject_id][ "psnr_01" ] = peak_signal_noise_ratio( - gt_image_infill, output_infill, data_range=1.0 + gt_image_infill, output_infill, data_range=(0,1) ).item() # same as above but with epsilon for robustness overall_stats_dict[current_subject_id][ "psnr_01_eps" ] = peak_signal_noise_ratio( - gt_image_infill, output_infill, epsilon=sys.float_info.epsilon + gt_image_infill, output_infill, data_range=(0,1), epsilon=sys.float_info.epsilon ).item() pprint(overall_stats_dict) From 24d175f10cb95f87470bff8791325c7a261f246d Mon Sep 17 00:00:00 2001 From: FelixSteinbauer <43598998+FelixSteinbauer@users.noreply.github.com> Date: Thu, 10 Aug 2023 17:22:00 +0200 Subject: [PATCH 3/3] Added range computation for the torchmetrics call --- GANDLF/metrics/synthesis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index 3b56e61b0..9aa4da466 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -60,7 +60,7 @@ def peak_signal_noise_ratio(target, prediction, data_range=None, epsilon=None) - """ if epsilon == None: - psnr = PeakSignalNoiseRatio(data_range=data_range) + psnr = PeakSignalNoiseRatio() if data_range == None else PeakSignalNoiseRatio(data_range=data_range[1]-data_range[0]) return psnr(preds=prediction, target=target) else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0' mse = mean_squared_error(target, prediction)