diff --git a/documentation/how_to_use_nnunet.md b/documentation/how_to_use_nnunet.md index 0cfe07139..290ae8873 100644 --- a/documentation/how_to_use_nnunet.md +++ b/documentation/how_to_use_nnunet.md @@ -290,6 +290,13 @@ from the respective training). You can pick these files from any of the ensemble ## How to run inference with pretrained models See [here](run_inference_with_pretrained_models.md) +## How to Deploy and Run Inference with YOUR Pretrained Models +To facilitate the use of pretrained models on a different computer for inference purposes, follow these streamlined steps: +1. Exporting the Model: Utilize the `nnUNetv2_export_model_to_zip` function to package your trained model into a .zip file. This file will contain all necessary model files. +2. Transferring the Model: Transfer the .zip file to the target computer where inference will be performed. +3. Importing the Model: On the new PC, use the `nnUNetv2_install_pretrained_model_from_zip` to load the pretrained model from the .zip file. +Please note that both computers must have nnU-Net installed along with all its dependencies to ensure compatibility and functionality of the model. + [//]: # (## Examples) [//]: # () diff --git a/documentation/pretraining_and_finetuning.md b/documentation/pretraining_and_finetuning.md index 5360eb74e..5c3f4d0c2 100644 --- a/documentation/pretraining_and_finetuning.md +++ b/documentation/pretraining_and_finetuning.md @@ -2,7 +2,7 @@ ## Intro -So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some source dataset +So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some pretraining dataset and then use the final network weights as initialization for your target dataset. As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a @@ -16,11 +16,11 @@ how the resulting weights can then be used for initialization. Throughout this README we use the following terminology: -- `source dataset` is the dataset you intend to run the pretraining on +- `pretraining dataset` is the dataset you intend to run the pretraining on (former: source dataset) - `target dataset` is the dataset you are interested in; the one you wish to fine tune on -## Pretraining on the source dataset +## Training on the pretraining dataset In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are only interested in the target dataset, we first need to run experiment planning (and preprocessing) for it: @@ -29,19 +29,19 @@ only interested in the target dataset, we first need to run experiment planning nnUNetv2_plan_and_preprocess -d TARGET_DATASET ``` -Then we need to extract the dataset fingerprint of the source dataset, if not yet available: +Then we need to extract the dataset fingerprint of the pretraining dataset, if not yet available: ```bash -nnUNetv2_extract_fingerprint -d SOURCE_DATASET +nnUNetv2_extract_fingerprint -d PRETRAINING_DATASET ``` -Now we can take the plans from the target dataset and transfer it to the source: +Now we can take the plans from the target dataset and transfer it to the pretraining dataset: ```bash -nnUNetv2_move_plans_between_datasets -s TARGET_DATASET -t SOURCE_DATASET -sp TARGET_PLANS_IDENTIFIER -tp SOURCE_PLANS_IDENTIFIER +nnUNetv2_move_plans_between_datasets -s PRETRAINING_DATASET -t TARGET_DATASET -sp PRETRAINING_PLANS_IDENTIFIER -tp TARGET_PLANS_IDENTIFIER ``` -`SOURCE_PLANS_IDENTIFIER` is hereby probably nnUNetPlans unless you changed the experiment planner in +`PRETRAINING_PLANS_IDENTIFIER` is hereby probably nnUNetPlans unless you changed the experiment planner in nnUNetv2_plan_and_preprocess. For `TARGET_PLANS_IDENTIFIER` we recommend you set something custom in order to not overwrite default plans. @@ -51,16 +51,16 @@ work well (but it could, depending on the schemes!). Note on CT normalization: Yes, also the clip values, mean and std are transferred! -Now you can run the preprocessing on the source task: +Now you can run the preprocessing on the pretraining dataset: ```bash -nnUNetv2_preprocess -d SOURCE_DATSET -plans_name TARGET_PLANS_IDENTIFIER +nnUNetv2_preprocess -d PRETRAINING_DATASET -plans_name TARGET_PLANS_IDENTIFIER ``` And run the training as usual: ```bash -nnUNetv2_train SOURCE_DATSET CONFIG all -p TARGET_PLANS_IDENTIFIER +nnUNetv2_train PRETRAINING_DATASET CONFIG all -p TARGET_PLANS_IDENTIFIER ``` Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data. diff --git a/nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py b/nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py new file mode 100644 index 000000000..e8f878088 --- /dev/null +++ b/nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py @@ -0,0 +1,60 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json + + +if __name__ == '__main__': + """ + How to train our submission to the JHU benchmark + + 1. Execute this script here to convert the dataset into nnU-Net format. Adapt the paths to your system! + 2. Run planning and preprocessing: `nnUNetv2_plan_and_preprocess -d 224 -npfp 64 -np 64 -c 3d_fullres -pl + nnUNetPlannerResEncL_torchres`. Adapt the number of processes to your System (-np; -npfp)! Note that each process + will again spawn 4 threads for resampling. This custom planner replaces the nnU-Net default resampling scheme with + a torch-based implementation which is faster but less accurate. This is needed to satisfy the inference speed + constraints. + 3. Run training with `nnUNetv2_train 224 3d_fullres all -p nnUNetResEncUNetLPlans_torchres`. 24GB VRAM required, + training will take ~28-30h. + """ + + + base = '/home/isensee/Downloads/AbdomenAtlas1.0Mini' + cases = subdirs(base, join=False, prefix='BDMAP') + + target_dataset_id = 224 + target_dataset_name = f'Dataset{target_dataset_id:3.0f}_AbdomenAtlas1.0' + + raw_dir = '/home/isensee/drives/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/2024_JHU_benchmark' + maybe_mkdir_p(join(raw_dir, target_dataset_name)) + imagesTr = join(raw_dir, target_dataset_name, 'imagesTr') + labelsTr = join(raw_dir, target_dataset_name, 'labelsTr') + maybe_mkdir_p(imagesTr) + maybe_mkdir_p(labelsTr) + + for case in cases: + shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz')) + shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz')) + + labels = { + "background": 0, + "aorta": 1, + "gall_bladder": 2, + "kidney_left": 3, + "kidney_right": 4, + "liver": 5, + "pancreas": 6, + "postcava": 7, + "spleen": 8, + "stomach": 9 + } + + generate_dataset_json( + join(raw_dir, target_dataset_name), + {0: 'nonCT'}, # this was a mistake we did at the beginning and we keep it like that here for consistency + labels, + len(cases), + '.nii.gz', + None, + target_dataset_name, + overwrite_image_reader_writer='NibabelIOWithReorient' + ) \ No newline at end of file diff --git a/nnunetv2/evaluation/evaluate_predictions.py b/nnunetv2/evaluation/evaluate_predictions.py index a7af531e4..1ecbf3255 100644 --- a/nnunetv2/evaluation/evaluate_predictions.py +++ b/nnunetv2/evaluation/evaluate_predictions.py @@ -92,7 +92,6 @@ def compute_metrics(reference_file: str, prediction_file: str, image_reader_writ # load images seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file) seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file) - # spacing = seg_ref_dict['spacing'] ignore_mask = seg_ref == ignore_label if ignore_label is not None else None diff --git a/nnunetv2/evaluation/find_best_configuration.py b/nnunetv2/evaluation/find_best_configuration.py index 7e9f77420..f585b80d9 100644 --- a/nnunetv2/evaluation/find_best_configuration.py +++ b/nnunetv2/evaluation/find_best_configuration.py @@ -3,8 +3,9 @@ from copy import deepcopy from typing import Union, List, Tuple -from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json - +from batchgenerators.utilities.file_and_folder_operations import ( + load_json, join, isdir, listdir, save_json +) from nnunetv2.configuration import default_num_processes from nnunetv2.ensembling.ensemble import ensemble_crossvalidations from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results @@ -320,6 +321,11 @@ def accumulate_crossval_results_entry_point(): merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}') else: merged_output_folder = args.o + if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0: + raise FileExistsError( + f"Output folder {merged_output_folder} exists and is not empty. " + f"To avoid data loss, nnUNet requires an empty output folder." + ) accumulate_cv_results(trained_model_folder, merged_output_folder, args.f) diff --git a/nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py b/nnunetv2/experiment_planning/experiment_planners/resampling/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py b/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py new file mode 100644 index 000000000..ee5adc7b9 --- /dev/null +++ b/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py @@ -0,0 +1,181 @@ +from typing import Union, List, Tuple + +from nnunetv2.configuration import ANISO_THRESHOLD +from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner +from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \ + nnUNetPlannerResEncL +from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet + + +class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 24, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + return self.plans_identifier + '_' + configuration_name + + def determine_resampling(self, *args, **kwargs): + """ + returns what functions to use for resampling data and seg, respectively. Also returns kwargs + resampling function must be callable(data, current_spacing, new_spacing, **kwargs) + + determine_resampling is called within get_plans_for_configuration to allow for different functions for each + configuration + """ + resampling_data = resample_torch_fornnunet + resampling_data_kwargs = { + "is_seg": False, + 'force_separate_z': False, + 'memefficient_seg_resampling': False + } + resampling_seg = resample_torch_fornnunet + resampling_seg_kwargs = { + "is_seg": True, + 'force_separate_z': False, + 'memefficient_seg_resampling': False + } + return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs + + def determine_segmentation_softmax_export_fn(self, *args, **kwargs): + """ + function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be + used as target. current_spacing and new_spacing are merely there in case we want to use it somehow + + determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different + functions for each configuration + + """ + resampling_fn = resample_torch_fornnunet + resampling_fn_kwargs = { + "is_seg": False, + 'force_separate_z': False, + 'memefficient_seg_resampling': False + } + return resampling_fn, resampling_fn_kwargs + + +class nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 24, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + return self.plans_identifier + '_' + configuration_name + + def determine_resampling(self, *args, **kwargs): + """ + returns what functions to use for resampling data and seg, respectively. Also returns kwargs + resampling function must be callable(data, current_spacing, new_spacing, **kwargs) + + determine_resampling is called within get_plans_for_configuration to allow for different functions for each + configuration + """ + resampling_data = resample_torch_fornnunet + resampling_data_kwargs = { + "is_seg": False, + 'force_separate_z': None, + 'memefficient_seg_resampling': False, + 'separate_z_anisotropy_threshold': ANISO_THRESHOLD + } + resampling_seg = resample_torch_fornnunet + resampling_seg_kwargs = { + "is_seg": True, + 'force_separate_z': None, + 'memefficient_seg_resampling': False, + 'separate_z_anisotropy_threshold': ANISO_THRESHOLD + } + return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs + + def determine_segmentation_softmax_export_fn(self, *args, **kwargs): + """ + function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be + used as target. current_spacing and new_spacing are merely there in case we want to use it somehow + + determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different + functions for each configuration + + """ + resampling_fn = resample_torch_fornnunet + resampling_fn_kwargs = { + "is_seg": False, + 'force_separate_z': None, + 'memefficient_seg_resampling': False, + 'separate_z_anisotropy_threshold': ANISO_THRESHOLD + } + return resampling_fn, resampling_fn_kwargs + + +class nnUNetPlanner_torchres(ExperimentPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but different plans file can have configurations with the + same name. In order to distinguish the associated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + return self.plans_identifier + '_' + configuration_name + + def determine_resampling(self, *args, **kwargs): + """ + returns what functions to use for resampling data and seg, respectively. Also returns kwargs + resampling function must be callable(data, current_spacing, new_spacing, **kwargs) + + determine_resampling is called within get_plans_for_configuration to allow for different functions for each + configuration + """ + resampling_data = resample_torch_fornnunet + resampling_data_kwargs = { + "is_seg": False, + 'force_separate_z': False, + 'memefficient_seg_resampling': False + } + resampling_seg = resample_torch_fornnunet + resampling_seg_kwargs = { + "is_seg": True, + 'force_separate_z': False, + 'memefficient_seg_resampling': False + } + return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs + + def determine_segmentation_softmax_export_fn(self, *args, **kwargs): + """ + function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be + used as target. current_spacing and new_spacing are merely there in case we want to use it somehow + + determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different + functions for each configuration + + """ + resampling_fn = resample_torch_fornnunet + resampling_fn_kwargs = { + "is_seg": False, + 'force_separate_z': False, + 'memefficient_seg_resampling': False + } + return resampling_fn, resampling_fn_kwargs diff --git a/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py b/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py index 1552e2067..012950b82 100644 --- a/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py +++ b/nnunetv2/experiment_planning/experiment_planners/residual_unets/residual_encoder_unet_planners.py @@ -6,6 +6,7 @@ from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm +from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet from torch import nn from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner diff --git a/nnunetv2/imageio/nibabel_reader_writer.py b/nnunetv2/imageio/nibabel_reader_writer.py index 78fb17ac1..2854da4b5 100644 --- a/nnunetv2/imageio/nibabel_reader_writer.py +++ b/nnunetv2/imageio/nibabel_reader_writer.py @@ -31,8 +31,6 @@ class NibabelIO(BaseReaderWriter): supported_file_endings = [ '.nii', '.nii.gz', - '.nrrd', - '.mha' ] def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: @@ -110,8 +108,6 @@ class NibabelIOWithReorient(BaseReaderWriter): supported_file_endings = [ '.nii', '.nii.gz', - '.nrrd', - '.mha' ] def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: diff --git a/nnunetv2/inference/JHU_inference.py b/nnunetv2/inference/JHU_inference.py index d57c2a606..0933600a9 100644 --- a/nnunetv2/inference/JHU_inference.py +++ b/nnunetv2/inference/JHU_inference.py @@ -176,7 +176,7 @@ def predict_from_data_iterator(self, predictor.initialize_from_trained_model_folder( args.model, ('all', ), - 'checkpoint_latest.pth' + 'checkpoint_final.pth' ) # we need to create list of list of input files diff --git a/nnunetv2/inference/export_prediction.py b/nnunetv2/inference/export_prediction.py index 33035676b..f5cdb958d 100644 --- a/nnunetv2/inference/export_prediction.py +++ b/nnunetv2/inference/export_prediction.py @@ -23,14 +23,15 @@ def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits torch.set_num_threads(num_threads_torch) # resample to original shape + spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward] current_spacing = configuration_manager.spacing if \ len(configuration_manager.spacing) == \ len(properties_dict['shape_after_cropping_and_before_resampling']) else \ - [properties_dict['spacing'][0], *configuration_manager.spacing] + [spacing_transposed[0], *configuration_manager.spacing] predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits, properties_dict['shape_after_cropping_and_before_resampling'], current_spacing, - properties_dict['spacing']) + [properties_dict['spacing'][i] for i in plans_manager.transpose_forward]) # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because # apply_inference_nonlin will convert to torch predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits) @@ -123,13 +124,14 @@ def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: if isinstance(dataset_json_dict_or_file, str): dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) + spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward] # resample to original shape current_spacing = configuration_manager.spacing if \ len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \ - [properties_dict['spacing'][0], *configuration_manager.spacing] + [spacing_transposed[0], *configuration_manager.spacing] target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \ len(properties_dict['shape_after_cropping_and_before_resampling']) else \ - [properties_dict['spacing'][0], *configuration_manager.spacing] + [spacing_transposed[0], *configuration_manager.spacing] predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted, target_shape, current_spacing, diff --git a/nnunetv2/preprocessing/preprocessors/default_preprocessor.py b/nnunetv2/preprocessing/preprocessors/default_preprocessor.py index 7e0068b9d..8b1abf7b2 100644 --- a/nnunetv2/preprocessing/preprocessors/default_preprocessor.py +++ b/nnunetv2/preprocessing/preprocessors/default_preprocessor.py @@ -230,15 +230,17 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan # multiprocessing magic. r = [] with multiprocessing.get_context("spawn").Pool(num_processes) as p: + remaining = list(range(len(dataset))) + # p is pretty nifti. If we kill workers they just respawn but don't do any work. + # So we need to store the original pool of workers. + workers = [j for j in p._pool] + for k in dataset.keys(): r.append(p.starmap_async(self.run_case_save, ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'], plans_manager, configuration_manager, dataset_json),))) - remaining = list(range(len(dataset))) - # p is pretty nifti. If we kill workers they just respawn but don't do any work. - # So we need to store the original pool of workers. - workers = [j for j in p._pool] + with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar: while len(remaining) > 0: all_alive = all([j.is_alive() for j in workers]) @@ -251,6 +253,8 @@ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plan 'an error message, out of RAM is likely the problem. In that case ' 'reducing the number of workers might help') done = [i for i in remaining if r[i].ready()] + # get done so that errors can be raised + _ = [r[i].get() for i in done] for _ in done: r[_].get() # allows triggering errors pbar.update() diff --git a/nnunetv2/preprocessing/resampling/default_resampling.py b/nnunetv2/preprocessing/resampling/default_resampling.py index d205be249..40408de27 100644 --- a/nnunetv2/preprocessing/resampling/default_resampling.py +++ b/nnunetv2/preprocessing/resampling/default_resampling.py @@ -1,11 +1,13 @@ from collections import OrderedDict +from copy import deepcopy from typing import Union, Tuple, List import numpy as np import pandas as pd +import sklearn import torch from batchgenerators.augmentations.utils import resize_segmentation -from scipy.ndimage.interpolation import map_coordinates +from scipy.ndimage import map_coordinates from skimage.transform import resize from nnunetv2.configuration import ANISO_THRESHOLD @@ -29,7 +31,11 @@ def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray], return new_shape -def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): +def determine_do_sep_z_and_axis( + force_separate_z: bool, + current_spacing, + new_spacing, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]: if force_separate_z is not None: do_separate_z = force_separate_z if force_separate_z: @@ -50,12 +56,14 @@ def determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, if axis is not None: if len(axis) == 3: do_separate_z = False + axis = None elif len(axis) == 2: # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample # separately in the out of plane axis do_separate_z = False + axis = None else: - pass + axis = axis[0] return do_separate_z, axis @@ -135,8 +143,7 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L data = data.astype(float, copy=False) if do_separate_z: # print("separate z, order in z is", order_z, "order inplane is", order) - assert len(axis) == 1, "only one anisotropic axis supported" - axis = axis[0] + assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic' if axis == 0: new_shape_2d = new_shape[1:] elif axis == 1: @@ -145,20 +152,23 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L new_shape_2d = new_shape[:-1] for c in range(data.shape[0]): - reshaped_here = np.zeros((data.shape[1], *new_shape_2d)) + tmp = deepcopy(new_shape) + tmp[axis] = shape[axis] + reshaped_here = np.zeros(tmp) for slice_id in range(shape[axis]): if axis == 0: reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs) elif axis == 1: - reshaped_here[slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs) + reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs) else: - reshaped_here[slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs) + reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs) if shape[axis] != new_shape[axis]: # The following few lines are blatantly copied and modified from sklearn's resize() rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] orig_rows, orig_cols, orig_dim = reshaped_here.shape + # align_corners=False row_scale = float(orig_rows) / rows col_scale = float(orig_cols) / cols dim_scale = float(orig_dim) / dim @@ -187,3 +197,10 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L else: # print("no resampling necessary") return data + + +if __name__ == '__main__': + input_array = np.random.random((1, 42, 231, 142)) + output_shape = (52, 256, 256) + out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=3, order=1, order_z=0, do_separate_z=True) + print(out.shape, input_array.shape) diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index d3803f391..b23847cb2 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -238,13 +238,24 @@ def initialize(self): def _do_i_compile(self): # new default: compile is enabled! + # compile does not work on mps + if self.device == torch.device('mps'): + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because of unsupported mps device") + return False + # CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable if self.device == torch.device('cpu'): + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because device is CPU") return False # default torch.compile doesn't work on windows because there are apparently no triton wheels for it # https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2 if os.name == 'nt': + if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'): + self.print_to_log_file("INFO: torch.compile disabled because Windows is not natively supported. If " + "you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2") return False if 'nnUNet_compile' not in os.environ.keys():