Skip to content

Commit

Permalink
Merge branch 'master' into assert_messages_flush_to_log_file
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati authored Aug 19, 2024
2 parents 421a944 + e9d92ae commit 69a86e3
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,16 @@ def validate_network(
if ext in [".jpg", ".jpeg", ".png"]:
pred_mask = pred_mask.astype(np.uint8)

## special case for 2D
if image.shape[-1] > 1:
result_image = sitk.GetImageFromArray(pred_mask)
else:
result_image = sitk.GetImageFromArray(pred_mask.squeeze(0))
pred_mask = (
pred_mask.squeeze(0)
if pred_mask.shape[0] == 1
else (
pred_mask.squeeze(-1)
if pred_mask.shape[-1] == 1
else pred_mask
)
)
result_image = sitk.GetImageFromArray(pred_mask)
result_image.CopyInformation(img_for_metadata)

# this handles cases that need resampling/resizing
Expand Down

0 comments on commit 69a86e3

Please sign in to comment.