Merge branch 'master' into dev
strasserpatrick committed Apr 28, 2024
2 parents 2c97db3 + 5db9604 commit a6b2b47
Showing 8 changed files with 38 additions and 53 deletions.
4 changes: 2 additions & 2 deletions nnunetv2/inference/data_iterators.py
@@ -146,7 +146,7 @@ def __init__(self, list_of_lists: List[List[str]],

     def generate_train_batch(self):
         idx = self.get_indices()[0]
-        files, seg_prev_stage, ofile = self._data[idx][0]
+        files, seg_prev_stage, ofile = self._data[idx]
         # if we have a segmentation from the previous stage we have to process it together with the images so that we
         # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
         # preprocessing and then there might be misalignments
@@ -190,7 +190,7 @@ def __init__(self, list_of_images: List[np.ndarray],

     def generate_train_batch(self):
         idx = self.get_indices()[0]
-        image, seg_prev_stage, props, ofname = self._data[idx][0]
+        image, seg_prev_stage, props, ofname = self._data[idx]
         # if we have a segmentation from the previous stage we have to process it together with the images so that we
         # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
         # preprocessing and then there might be misalignments
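Note on both hunks above: `self._data[idx]` already is the per-case tuple, so the trailing `[0]` unpacked the wrong object (the file list itself). A minimal sketch of the failure mode, with hypothetical case data:

    # hypothetical single-case data: one (files, seg_prev_stage, ofile) tuple per case
    data = [(['case0_0000.nii.gz'], None, 'case0')]
    idx = 0
    # old code: files, seg_prev_stage, ofile = data[idx][0]
    #   -> ValueError: data[idx][0] is the file list, not the 3-tuple
    files, seg_prev_stage, ofile = data[idx]  # fixed: unpack the whole tuple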
4 changes: 3 additions & 1 deletion nnunetv2/inference/predict_from_raw_data.py
@@ -96,7 +96,9 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
         num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
         trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                     trainer_name, 'nnunetv2.training.nnUNetTrainer')
-
+        if trainer_class is None:
+            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
+                               f'Please place it there (in any .py file)!')
         network = trainer_class.build_network_architecture(
             configuration_manager.network_arch_class_name,
             configuration_manager.network_arch_init_kwargs,
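`recursive_find_python_class` returns None when no matching class is found (see the find_class_by_name.py hunk further down), so without the new guard the very next line would fail with an unhelpful `AttributeError: 'NoneType' object has no attribute 'build_network_architecture'`. A rough sketch of the case this turns into a clear error, using a hypothetical trainer name:

    import nnunetv2
    from batchgenerators.utilities.file_and_folder_operations import join
    from nnunetv2.utilities.find_class_by_name import recursive_find_python_class

    # with a typo'd or missing trainer, the lookup yields None rather than raising
    trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                'nnUNetTrainerDoesNotExist',  # hypothetical name
                                                'nnunetv2.training.nnUNetTrainer')
    assert trainer_class is None  # previously this crashed later with AttributeError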
13 changes: 7 additions & 6 deletions nnunetv2/preprocessing/preprocessors/default_preprocessor.py
@@ -14,20 +14,20 @@
 import multiprocessing
 import shutil
 from time import sleep
-from typing import Union, Tuple
+from typing import Tuple, Union

+import nnunetv2
 import numpy as np
 from batchgenerators.utilities.file_and_folder_operations import *
+from tqdm import tqdm

-import nnunetv2
 from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
 from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
 from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape
 from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
 from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
 from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
-from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
-    create_lists_from_splitted_dataset_folder, get_filenames_of_train_images_and_targets
-from tqdm import tqdm
+from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets


 class DefaultPreprocessor(object):
@@ -41,7 +41,7 @@ def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict,
                      plans_manager: PlansManager, configuration_manager: ConfigurationManager,
                      dataset_json: Union[dict, str]):
         # let's not mess up the inputs!
-        data = np.copy(data)
+        data = data.astype(np.float32)  # this creates a copy
         if seg is not None:
             assert data.shape[1:] == seg.shape[1:], "Shape mismatch between image and segmentation. Please fix your dataset and make use of the --verify_dataset_integrity flag to ensure everything is correct"
             seg = np.copy(seg)
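`ndarray.astype` copies by default, so casting to float32 here both isolates the caller's array and performs the cast the pipeline needs anyway, in a single allocation (`np.copy` kept the original dtype, deferring the cast). For example:

    import numpy as np

    a = np.arange(4, dtype=np.int16)
    b = a.astype(np.float32)  # astype(copy=True) is the default: new buffer
    b[0] = -1
    assert a[0] == 0          # caller's array is untouched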
@@ -252,6 +252,7 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan
                                        'reducing the number of workers might help')
                     done = [i for i in remaining if r[i].ready()]
                     for _ in done:
+                        r[_].get()  # allows triggering errors
                         pbar.update()
                     remaining = [i for i in remaining if i not in done]
                     sleep(0.1)
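Exceptions raised in pool workers are captured inside the `AsyncResult`; `ready()` also returns True for failed jobs, so without `get()` a crashed worker would silently count as done. Calling `get()` re-raises the worker's exception in the parent process. A standalone sketch of the behavior (not nnU-Net code):

    import multiprocessing

    def work(i):
        raise ValueError(f'worker {i} failed')

    if __name__ == '__main__':
        with multiprocessing.get_context('spawn').Pool(2) as pool:
            jobs = [pool.apply_async(work, (i,)) for i in range(2)]
            for job in jobs:
                job.wait()  # completes without raising
                job.get()   # re-raises the worker's ValueError here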
44 changes: 18 additions & 26 deletions nnunetv2/preprocessing/resampling/default_resampling.py
@@ -124,7 +124,7 @@ def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],

 def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
                          is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,
-                         do_separate_z: bool = False, order_z: int = 0):
+                         do_separate_z: bool = False, order_z: int = 0, dtype_out = None):
     """
     separate_z=True will resample with order 0 along z
     :param data:
@@ -145,11 +145,13 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
     else:
         resize_fn = resize
         kwargs = {'mode': 'edge', 'anti_aliasing': False}
-    dtype_data = data.dtype
     shape = np.array(data[0].shape)
     new_shape = np.array(new_shape)
+    if dtype_out is None:
+        dtype_out = data.dtype
+    reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out)
     if np.any(shape != new_shape):
-        data = data.astype(float)
+        data = data.astype(float, copy=False)
         if do_separate_z:
             # print("separate z, order in z is", order_z, "order inplane is", order)
             assert len(axis) == 1, "only one anisotropic axis supported"
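The new `dtype_out` argument lets callers pick the output dtype (defaulting to the input's), the result array is now preallocated once as `reshaped_final`, and `astype(float, copy=False)` skips a redundant copy when `data` is already float64. E.g.:

    import numpy as np

    data = np.random.rand(2, 8, 8, 8)              # float64 already
    assert data.astype(float, copy=False) is data  # no copy is made
    assert data.astype(float) is not data          # copy=True (the default) always copies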
@@ -161,22 +163,20 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
             else:
                 new_shape_2d = new_shape[:-1]

-            reshaped_final_data = []
             for c in range(data.shape[0]):
-                reshaped_data = []
+                reshaped_here = np.zeros((data.shape[1], *new_shape_2d))
                 for slice_id in range(shape[axis]):
                     if axis == 0:
-                        reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs))
+                        reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)
                     elif axis == 1:
-                        reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs))
+                        reshaped_here[slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)
                     else:
-                        reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs))
-                reshaped_data = np.stack(reshaped_data, axis)
+                        reshaped_here[slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)
                 if shape[axis] != new_shape[axis]:

                     # The following few lines are blatantly copied and modified from sklearn's resize()
                     rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
-                    orig_rows, orig_cols, orig_dim = reshaped_data.shape
+                    orig_rows, orig_cols, orig_dim = reshaped_here.shape

                     row_scale = float(orig_rows) / rows
                     col_scale = float(orig_cols) / cols
@@ -189,28 +189,20 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],

                     coord_map = np.array([map_rows, map_cols, map_dims])
                     if not is_seg or order_z == 0:
-                        reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z,
-                                                                   mode='nearest')[None])
+                        reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None]
                     else:
-                        unique_labels = np.sort(pd.unique(reshaped_data.ravel()))  # np.unique(reshaped_data)
-                        reshaped = np.zeros(new_shape, dtype=dtype_data)
-
+                        unique_labels = np.sort(pd.unique(reshaped_here.ravel()))  # np.unique(reshaped_data)
                         for i, cl in enumerate(unique_labels):
-                            reshaped_multihot = np.round(
-                                map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z,
-                                                mode='nearest'))
-                            reshaped[reshaped_multihot > 0.5] = cl
-                        reshaped_final_data.append(reshaped[None])
+                            reshaped_final[c][np.round(
+                                map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z,
+                                                mode='nearest')) > 0.5] = cl
                 else:
-                    reshaped_final_data.append(reshaped_data[None])
-            reshaped_final_data = np.vstack(reshaped_final_data)
+                    reshaped_final[c] = reshaped_here
         else:
             # print("no separate z, order", order)
-            reshaped = []
             for c in range(data.shape[0]):
-                reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None])
-            reshaped_final_data = np.vstack(reshaped)
-        return reshaped_final_data.astype(dtype_data)
+                reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
+        return reshaped_final
     else:
         # print("no resampling necessary")
         return data
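Net effect of these resampling hunks: each channel is written straight into the preallocated `reshaped_final` instead of being collected in Python lists and merged via `np.stack`/`np.vstack`, which previously held the per-slice list and the stacked copy in memory at the same time; the label branch likewise writes each rounded label mask into `reshaped_final[c]` in place. A schematic of the memory pattern (illustrative shapes only, not the nnU-Net code):

    import numpy as np

    def old_style(channels):
        parts = [c[None] for c in channels]
        return np.vstack(parts)                  # peak: all parts plus the stacked copy

    def new_style(channels, out_shape, dtype):
        out = np.zeros(out_shape, dtype=dtype)   # single allocation up front
        for c, ch in enumerate(channels):
            out[c] = ch                          # written in place, no final merge
        return out

    channels = [np.random.rand(8, 8, 8) for _ in range(4)]
    assert np.array_equal(old_style(channels), new_style(channels, (4, 8, 8, 8), np.float64))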
@@ -22,17 +22,11 @@ def __init__(self, regions: Union[List, Tuple],

     def __call__(self, **data_dict):
         seg = data_dict.get(self.seg_key)
-        num_regions = len(self.regions)
         if seg is not None:
-            seg_shp = seg.shape
-            output_shape = list(seg_shp)
-            output_shape[1] = num_regions
-            region_output = np.zeros(output_shape, dtype=seg.dtype)
-            for b in range(seg_shp[0]):
-                for region_id, region_source_labels in enumerate(self.regions):
-                    if not isinstance(region_source_labels, (list, tuple)):
-                        region_source_labels = (region_source_labels, )
-                    for label_value in region_source_labels:
-                        region_output[b, region_id][seg[b, self.seg_channel] == label_value] = 1
-            data_dict[self.output_key] = region_output
+            b, c, *shape = seg.shape
+            region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)
+            for region_id, region_labels in enumerate(self.regions):
+                region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
+            data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)
         return data_dict
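The rewritten transform drops the four nested Python loops in favor of one vectorized `np.isin` per region over the whole batch; since `np.isin` accepts either a scalar or a sequence as `test_elements`, the old normalization of single labels into tuples is no longer needed. Equivalence sketch with made-up labels:

    import numpy as np

    seg = np.random.randint(0, 4, size=(2, 1, 8, 8))  # (batch, channels, *spatial)
    regions = [(1, 2), 3]                             # a merged region and a single-label region

    region_output = np.zeros((2, len(regions), 8, 8), dtype=bool)
    for region_id, region_labels in enumerate(regions):
        region_output[:, region_id] |= np.isin(seg[:, 0], region_labels)
    assert region_output.shape == (2, 2, 8, 8)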

4 changes: 0 additions & 4 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -1291,19 +1291,15 @@ def run_training(self):

             self.on_train_epoch_start()
             train_outputs = []
-            st = time()
             for batch_id in range(self.num_iterations_per_epoch):
                 train_outputs.append(self.train_step(next(self.dataloader_train)))
-            print('train time', time() - st)
             self.on_train_epoch_end(train_outputs)

             with torch.no_grad():
                 self.on_validation_epoch_start()
                 val_outputs = []
-                st = time()
                 for batch_id in range(self.num_val_iterations_per_epoch):
                     val_outputs.append(self.validation_step(next(self.dataloader_val)))
-                print('val time', time() - st)
                 self.on_validation_epoch_end(val_outputs)

             self.on_epoch_end()
2 changes: 1 addition & 1 deletion nnunetv2/utilities/find_class_by_name.py
@@ -21,4 +21,4 @@ def recursive_find_python_class(folder: str, class_name: str, current_module: st
                 tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module)
                 if tr is not None:
                     break
-    return tr
\ No newline at end of file
+    return tr
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nnunetv2"
-version = "2.4"
+version = "2.4.2"
 requires-python = ">=3.9"
 description = "nnU-Net is a framework for out-of-the box image segmentation."
 readme = "readme.md"