Skip to content

Commit

Permalink
Used proper torchmetric parameter order and replaced own PSNR impleme…
Browse files Browse the repository at this point in the history
…ntation with torchmetric PSNR

- I replaced the self-implemented PSNR computation with the one provided by torchmetric. 
- The ordering of torchmetric function call arguments is actually predictions ("preds") and then target ("target"), not the other way around.
  • Loading branch information
FelixSteinbauer authored Jul 18, 2023
1 parent 8a8e437 commit 5409c32
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions GANDLF/metrics/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MeanSquaredError,
MeanSquaredLogError,
MeanAbsoluteError,
PeakSignalNoiseRatio,
)
from GANDLF.utils import get_image_from_tensor

Expand All @@ -25,7 +26,7 @@ def structural_similarity_index(target, prediction, mask=None) -> torch.Tensor:
torch.Tensor: The structural similarity index.
"""
ssim = StructuralSimilarityIndexMeasure(return_full_image=True)
_, ssim_idx_full_image = ssim(target, prediction)
_, ssim_idx_full_image = ssim(preds=prediction, target=target)
mask = torch.ones_like(ssim_idx_full_image) if mask is None else mask
try:
ssim_idx = ssim_idx_full_image[mask]
Expand All @@ -45,7 +46,7 @@ def mean_squared_error(target, prediction) -> torch.Tensor:
prediction (torch.Tensor): The prediction tensor.
"""
mse = MeanSquaredError()
return mse(target, prediction)
return mse(preds=prediction, target=target)


def peak_signal_noise_ratio(target, prediction) -> torch.Tensor:
Expand All @@ -56,12 +57,8 @@ def peak_signal_noise_ratio(target, prediction) -> torch.Tensor:
target (torch.Tensor): The target tensor.
prediction (torch.Tensor): The prediction tensor.
"""
mse = mean_squared_error(target, prediction)
return (
10.0
* torch.log10((torch.max(target) - torch.min(target)) ** 2)
/ (mse + sys.float_info.epsilon)
)
psnr = PeakSignalNoiseRatio()
return psnr(preds=prediction, target=target)


def mean_squared_log_error(target, prediction) -> torch.Tensor:
Expand All @@ -73,7 +70,7 @@ def mean_squared_log_error(target, prediction) -> torch.Tensor:
prediction (torch.Tensor): The prediction tensor.
"""
mle = MeanSquaredLogError()
return mle(target, prediction)
return mle(preds=prediction, target=target)


def mean_absolute_error(target, prediction) -> torch.Tensor:
Expand All @@ -85,7 +82,7 @@ def mean_absolute_error(target, prediction) -> torch.Tensor:
prediction (torch.Tensor): The prediction tensor.
"""
mae = MeanAbsoluteError()
return mae(target, prediction)
return mae(preds=prediction, target=target)


def _get_ncc_image(target, prediction) -> sitk.Image:
Expand Down

0 comments on commit 5409c32

Please sign in to comment.