Skip to content

Commit

Permalink
added torch resampling fn
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Apr 30, 2024
1 parent ef1b4c9 commit 8c4a9d0
Showing 1 changed file with 16 additions and 35 deletions.
51 changes: 16 additions & 35 deletions nnunetv2/preprocessing/resampling/default_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
return new_shape


def resample_data_or_seg_to_spacing(data: np.ndarray,
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False,
order: int = 3, order_z: int = 0,
force_separate_z: Union[bool, None] = False,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
if force_separate_z is not None:
do_separate_z = force_separate_z
if force_separate_z:
Expand All @@ -55,14 +49,25 @@ def resample_data_or_seg_to_spacing(data: np.ndarray,

if axis is not None:
if len(axis) == 3:
# every axis has the same spacing, this should never happen, why is this code here?
do_separate_z = False
elif len(axis) == 2:
# this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
# separately in the out of plane axis
do_separate_z = False
else:
pass
return do_separate_z, axis


def resample_data_or_seg_to_spacing(data: np.ndarray,
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False,
order: int = 3, order_z: int = 0,
force_separate_z: Union[bool, None] = False,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
separate_z_anisotropy_threshold)

if data is not None:
assert data.ndim == 4, "data must be c x y z"
Expand All @@ -86,34 +91,10 @@ def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
needed for segmentation export. Stupid, I know
"""
if isinstance(data, torch.Tensor):
data = data.cpu().numpy()
if force_separate_z is not None:
do_separate_z = force_separate_z
if force_separate_z:
axis = get_lowres_axis(current_spacing)
else:
axis = None
else:
if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
do_separate_z = True
axis = get_lowres_axis(current_spacing)
elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
do_separate_z = True
axis = get_lowres_axis(new_spacing)
else:
do_separate_z = False
axis = None
data = data.numpy()

if axis is not None:
if len(axis) == 3:
# every axis has the same spacing, this should never happen, why is this code here?
do_separate_z = False
elif len(axis) == 2:
# this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
# separately in the out of plane axis
do_separate_z = False
else:
pass
do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
separate_z_anisotropy_threshold)

if data is not None:
assert data.ndim == 4, "data must be c x y z"
Expand Down

0 comments on commit 8c4a9d0

Please sign in to comment.