Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding sequential inference #2248

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 95 additions & 23 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.helpers import empty_cache, dummy_context
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels, convert_labelmap_to_one_hot
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder

Expand Down Expand Up @@ -250,13 +250,87 @@ def predict_from_files(self,
if len(list_of_lists_or_source_folder) == 0:
return

if num_processes_preprocessing == 0 and num_processes_segmentation_export == 0:
return self._sequential_prediction(list_of_lists_or_source_folder, seg_from_prev_stage_files,
output_filename_truncated, save_probabilities)

data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
seg_from_prev_stage_files,
output_filename_truncated,
num_processes_preprocessing)

return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)

def _load_data_for_prediction(self, input_file, input_seg, properties, preprocessor, plans_manager,
configuration_manager, dataset_json, label_manager):
if properties is not None:
data, seg = preprocessor.run_case_npy(
input_file,
input_seg,
properties,
plans_manager,
configuration_manager,
dataset_json)
else:
data, seg, properties = preprocessor.run_case(
input_file,
input_seg,
plans_manager,
configuration_manager,
dataset_json)

if input_seg is not None:
seg_onehot = convert_labelmap_to_one_hot(input_seg[0], label_manager.foreground_labels, data.dtype)
data = np.vstack((data, seg_onehot))

data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
if self.device.type == 'cuda':
data = data.pin_memory()
return data, properties

@torch.inference_mode()
def _sequential_prediction(self, input_list_of_lists, seg_from_prev_stage_files,
output_filename_truncated, save_probabilities):
ret = []
configuration_manager = self.configuration_manager
preprocessor = configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)
plans_manager = self.plans_manager
dataset_json = self.dataset_json
label_manager = plans_manager.get_label_manager(dataset_json)

for i in range(len(input_list_of_lists)):
ofile = output_filename_truncated[i] if output_filename_truncated is not None else None
if ofile is not None:
print(f'\nPredicting {os.path.basename(ofile)}:')
else:
print(f'\nPredicting image of shape {data.shape}:')

data, properties = self._load_data_for_prediction(input_list_of_lists[i],
seg_from_prev_stage_files[
i] if seg_from_prev_stage_files is not None else None,
None,
preprocessor, plans_manager, configuration_manager,
dataset_json, label_manager)

prediction = self.predict_logits_from_preprocessed_data(data)

if ofile is not None:
print('resampling and export')
export_prediction_from_logits(
prediction, properties, self.configuration_manager, self.plans_manager, self.dataset_json, ofile,
save_probabilities)
print(f'done with {os.path.basename(ofile)}')
else:
print('resampling')
ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
prediction, self.plans_manager, self.configuration_manager, self.label_manager, properties,
save_probabilities))
print(f'\nDone with image of shape {data.shape}:')

compute_gaussian.cache_clear()
empty_cache(self.device)
return ret

def _internal_get_data_iterator_from_lists_of_filenames(self,
input_list_of_lists: List[List[str]],
seg_from_prev_stage_files: Union[List[str], None],
Expand Down Expand Up @@ -418,6 +492,7 @@ def predict_from_data_iterator(self,
empty_cache(self.device)
return ret

@torch.inference_mode()
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
segmentation_previous_stage: np.ndarray = None,
output_file_truncated: str = None,
Expand All @@ -435,35 +510,29 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
you need to transpose your axes AND your spacing from [x,y,z] to [z,y,x]!
image_properties must only have a 'spacing' key!
"""
ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
[output_file_truncated],
self.plans_manager, self.dataset_json, self.configuration_manager,
num_threads_in_multithreaded=1, verbose=self.verbose)
if self.verbose:
print('preprocessing')
dct = next(ppa)
data, properties = self._load_data_for_prediction(input_image, segmentation_previous_stage, image_properties,
self.configuration_manager.preprocessor_class(
verbose=self.verbose_preprocessing),
self.plans_manager,
self.configuration_manager,
self.dataset_json,
self.plans_manager.get_label_manager(self.dataset_json))

if self.verbose:
print('predicting')
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
predicted_logits = self.predict_logits_from_preprocessed_data(data)

if self.verbose:
print('resampling to original shape')
if output_file_truncated is not None:
export_prediction_from_logits(predicted_logits, dct['data_properties'], self.configuration_manager,
export_prediction_from_logits(predicted_logits, properties, self.configuration_manager,
self.plans_manager, self.dataset_json, output_file_truncated,
save_or_return_probabilities)
else:
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
self.configuration_manager,
self.label_manager,
dct['data_properties'],
return_probabilities=
save_or_return_probabilities)
if save_or_return_probabilities:
return ret[0], ret[1]
else:
return ret
return convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
self.configuration_manager,
self.label_manager,
properties,
return_probabilities=
save_or_return_probabilities)

def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -813,6 +882,9 @@ def predict_entry_point():
help="Use this to set the device the inference should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
parser.add_argument('--not_on_device', action='store_true', required=False, default=False,
help="Set this flag to not keep the entire case on device. Recommended for large cases that "
"occupy more VRAM than available")
parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False,
help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive '
'jobs)')
Expand Down Expand Up @@ -853,7 +925,7 @@ def predict_entry_point():
predictor = nnUNetPredictor(tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_device=True,
perform_everything_on_device=not args.not_on_device,
device=device,
verbose=args.verbose,
verbose_preprocessing=args.verbose,
Expand Down
Loading