From e9066b21e3414a61b2d68169773f6403d5f70b85 Mon Sep 17 00:00:00 2001 From: vahluw Date: Fri, 16 Aug 2024 16:59:27 -0400 Subject: [PATCH 1/2] Fixed bug in compute/forward_pass.py line 359 that caused error with 2D patches --- GANDLF/compute/forward_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py index b0131c3a5..b254fd5ec 100644 --- a/GANDLF/compute/forward_pass.py +++ b/GANDLF/compute/forward_pass.py @@ -356,7 +356,7 @@ def validate_network( if image.shape[-1] > 1: result_image = sitk.GetImageFromArray(pred_mask) else: - result_image = sitk.GetImageFromArray(pred_mask.squeeze(0)) + result_image = sitk.GetImageFromArray(pred_mask.squeeze(-1)) result_image.CopyInformation(img_for_metadata) # this handles cases that need resampling/resizing From 258cd900fc05047aa97966f79549db98e286a70c Mon Sep 17 00:00:00 2001 From: vahluw Date: Mon, 19 Aug 2024 13:30:17 -0400 Subject: [PATCH 2/2] Requested change pull request #922 --- GANDLF/compute/forward_pass.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/GANDLF/compute/forward_pass.py b/GANDLF/compute/forward_pass.py index e077cd52c..69efa15a9 100644 --- a/GANDLF/compute/forward_pass.py +++ b/GANDLF/compute/forward_pass.py @@ -337,11 +337,16 @@ def validate_network( if ext in [".jpg", ".jpeg", ".png"]: pred_mask = pred_mask.astype(np.uint8) - ## special case for 2D - if image.shape[-1] > 1: - result_image = sitk.GetImageFromArray(pred_mask) - else: - result_image = sitk.GetImageFromArray(pred_mask.squeeze(-1)) + pred_mask = ( + pred_mask.squeeze(0) + if pred_mask.shape[0] == 1 + else ( + pred_mask.squeeze(-1) + if pred_mask.shape[-1] == 1 + else pred_mask + ) + ) + result_image = sitk.GetImageFromArray(pred_mask) result_image.CopyInformation(img_for_metadata) # this handles cases that need resampling/resizing