Skip to content

Commit

Permalink
streamling parametrixation of aniso axis
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Jun 7, 2024
1 parent a737753 commit 9ffe0a5
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions nnunetv2/preprocessing/resampling/default_resampling.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections import OrderedDict
from copy import deepcopy
from typing import Union, Tuple, List

import numpy as np
import pandas as pd
import sklearn
import torch
from batchgenerators.augmentations.utils import resize_segmentation
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage import map_coordinates
from skimage.transform import resize
from nnunetv2.configuration import ANISO_THRESHOLD

Expand All @@ -29,7 +31,11 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
return new_shape


def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
def determine_do_sep_z_and_axis(
force_separate_z: bool,
current_spacing,
new_spacing,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]:
if force_separate_z is not None:
do_separate_z = force_separate_z
if force_separate_z:
Expand All @@ -50,12 +56,14 @@ def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
if axis is not None:
if len(axis) == 3:
do_separate_z = False
axis = None
elif len(axis) == 2:
# this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
# separately in the out of plane axis
do_separate_z = False
axis = None
else:
pass
axis = axis[0]
return do_separate_z, axis


Expand Down Expand Up @@ -135,8 +143,7 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L
data = data.astype(float, copy=False)
if do_separate_z:
# print("separate z, order in z is", order_z, "order inplane is", order)
assert len(axis) == 1, "only one anisotropic axis supported"
axis = axis[0]
assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic'
if axis == 0:
new_shape_2d = new_shape[1:]
elif axis == 1:
Expand All @@ -145,20 +152,23 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L
new_shape_2d = new_shape[:-1]

for c in range(data.shape[0]):
reshaped_here = np.zeros((data.shape[1], *new_shape_2d))
tmp = deepcopy(new_shape)
tmp[axis] = shape[axis]
reshaped_here = np.zeros(tmp)
for slice_id in range(shape[axis]):
if axis == 0:
reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)
elif axis == 1:
reshaped_here[slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
else:
reshaped_here[slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)
reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)
if shape[axis] != new_shape[axis]:

# The following few lines are blatantly copied and modified from sklearn's resize()
rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
orig_rows, orig_cols, orig_dim = reshaped_here.shape

# align_corners=False
row_scale = float(orig_rows) / rows
col_scale = float(orig_cols) / cols
dim_scale = float(orig_dim) / dim
Expand Down Expand Up @@ -187,3 +197,10 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L
else:
# print("no resampling necessary")
return data


if __name__ == '__main__':
input_array = np.random.random((1, 42, 231, 142))
output_shape = (52, 256, 256)
out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=3, order=1, order_z=0, do_separate_z=True)
print(out.shape, input_array.shape)

0 comments on commit 9ffe0a5

Please sign in to comment.