Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Apr 20, 2024
1 parent 9945333 commit 45a4be1
Showing 1 changed file with 17 additions and 25 deletions.
42 changes: 17 additions & 25 deletions nnunetv2/preprocessing/resampling/default_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],

def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,
do_separate_z: bool = False, order_z: int = 0):
do_separate_z: bool = False, order_z: int = 0, dtype_out = None):
"""
separate_z=True will resample with order 0 along z
:param data:
Expand All @@ -145,9 +145,11 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L
else:
resize_fn = resize
kwargs = {'mode': 'edge', 'anti_aliasing': False}
dtype_data = data.dtype
shape = np.array(data[0].shape)
new_shape = np.array(new_shape)
if dtype_out is None:
dtype_out = data.dtype
reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out)
if np.any(shape != new_shape):
data = data.astype(float)
if do_separate_z:
Expand All @@ -161,22 +163,20 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L
else:
new_shape_2d = new_shape[:-1]

reshaped_final_data = []
for c in range(data.shape[0]):
reshaped_data = []
reshaped_here = np.zeros((data.shape[1], *new_shape_2d))
for slice_id in range(shape[axis]):
if axis == 0:
reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs))
reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)
elif axis == 1:
reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs))
reshaped_here[slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
else:
reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs))
reshaped_data = np.stack(reshaped_data, axis)
reshaped_here[slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)
if shape[axis] != new_shape[axis]:

# The following few lines are blatantly copied and modified from sklearn's resize()
rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
orig_rows, orig_cols, orig_dim = reshaped_data.shape
orig_rows, orig_cols, orig_dim = reshaped_here.shape

row_scale = float(orig_rows) / rows
col_scale = float(orig_cols) / cols
Expand All @@ -189,28 +189,20 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L

coord_map = np.array([map_rows, map_cols, map_dims])
if not is_seg or order_z == 0:
reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z,
mode='nearest')[None])
reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None]
else:
unique_labels = np.sort(pd.unique(reshaped_data.ravel())) # np.unique(reshaped_data)
reshaped = np.zeros(new_shape, dtype=dtype_data)

unique_labels = np.sort(pd.unique(reshaped_here.ravel())) # np.unique(reshaped_data)
for i, cl in enumerate(unique_labels):
reshaped_multihot = np.round(
map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z,
mode='nearest'))
reshaped[reshaped_multihot > 0.5] = cl
reshaped_final_data.append(reshaped[None])
reshaped_final[c][np.round(
map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z,
mode='nearest')) > 0.5] = cl
else:
reshaped_final_data.append(reshaped_data[None])
reshaped_final_data = np.vstack(reshaped_final_data)
reshaped_final[c] = reshaped_here
else:
# print("no separate z, order", order)
reshaped = []
for c in range(data.shape[0]):
reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None])
reshaped_final_data = np.vstack(reshaped)
return reshaped_final_data.astype(dtype_data)
reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
return reshaped_final
else:
# print("no resampling necessary")
return data

0 comments on commit 45a4be1

Please sign in to comment.