
Commit

Merge branch 'MIC-DKFZ:master' into seq-inf
ancestor-mithril authored Jun 10, 2024
2 parents fef5a45 + 75c46fe commit 4a92224
Showing 14 changed files with 319 additions and 35 deletions.
7 changes: 7 additions & 0 deletions documentation/how_to_use_nnunet.md
@@ -290,6 +290,13 @@ from the respective training). You can pick these files from any of the ensemble
## How to run inference with pretrained models
See [here](run_inference_with_pretrained_models.md)

## How to Deploy and Run Inference with YOUR Pretrained Models
To use your own pretrained models on another computer for inference, follow these steps (a command sketch follows the list):
1. Export the model: use the `nnUNetv2_export_model_to_zip` command to package your trained model into a .zip file containing all necessary model files.
2. Transfer the model: copy the .zip file to the computer on which inference will be run.
3. Import the model: on that computer, use the `nnUNetv2_install_pretrained_model_from_zip` command to install the pretrained model from the .zip file.
Note that both computers must have nnU-Net and all of its dependencies installed for the model to work.
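A minimal command sketch of this workflow is shown below. The dataset ID, configuration, fold and file names are placeholders, and the exact flags of `nnUNetv2_export_model_to_zip` may differ between nnU-Net versions — consult `-h` on both commands for the authoritative interface.

```bash
# On the training machine: package the trained model (flags are illustrative; see nnUNetv2_export_model_to_zip -h)
nnUNetv2_export_model_to_zip -d DATASET_NAME_OR_ID -o my_model.zip -c 3d_fullres -f all

# Copy my_model.zip to the inference machine (scp, shared drive, ...), then install it there
nnUNetv2_install_pretrained_model_from_zip my_model.zip

# Run inference with the installed model as usual
nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c 3d_fullres -f all
```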

[//]: # (## Examples)

[//]: # ()
22 changes: 11 additions & 11 deletions documentation/pretraining_and_finetuning.md
@@ -2,7 +2,7 @@

## Intro

-So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some source dataset
+So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some pretraining dataset
and then use the final network weights as initialization for your target dataset.

As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a
@@ -16,11 +16,11 @@ how the resulting weights can then be used for initialization.

Throughout this README we use the following terminology:

-- `source dataset` is the dataset you intend to run the pretraining on
+- `pretraining dataset` is the dataset you intend to run the pretraining on (former: source dataset)
- `target dataset` is the dataset you are interested in; the one you wish to fine tune on


-## Pretraining on the source dataset
+## Training on the pretraining dataset

In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are
only interested in the target dataset, we first need to run experiment planning (and preprocessing) for it:
@@ -29,19 +29,19 @@ only interested in the target dataset, we first need to run experiment planning
nnUNetv2_plan_and_preprocess -d TARGET_DATASET
```

-Then we need to extract the dataset fingerprint of the source dataset, if not yet available:
+Then we need to extract the dataset fingerprint of the pretraining dataset, if not yet available:

```bash
-nnUNetv2_extract_fingerprint -d SOURCE_DATASET
+nnUNetv2_extract_fingerprint -d PRETRAINING_DATASET
```

-Now we can take the plans from the target dataset and transfer it to the source:
+Now we can take the plans from the target dataset and transfer it to the pretraining dataset:

```bash
-nnUNetv2_move_plans_between_datasets -s TARGET_DATASET -t SOURCE_DATASET -sp TARGET_PLANS_IDENTIFIER -tp SOURCE_PLANS_IDENTIFIER
+nnUNetv2_move_plans_between_datasets -s PRETRAINING_DATASET -t TARGET_DATASET -sp PRETRAINING_PLANS_IDENTIFIER -tp TARGET_PLANS_IDENTIFIER
```

-`SOURCE_PLANS_IDENTIFIER` is hereby probably nnUNetPlans unless you changed the experiment planner in
+`PRETRAINING_PLANS_IDENTIFIER` is hereby probably nnUNetPlans unless you changed the experiment planner in
nnUNetv2_plan_and_preprocess. For `TARGET_PLANS_IDENTIFIER` we recommend you set something custom in order to not
overwrite default plans.

@@ -51,16 +51,16 @@ work well (but it could, depending on the schemes!).

Note on CT normalization: Yes, also the clip values, mean and std are transferred!

-Now you can run the preprocessing on the source task:
+Now you can run the preprocessing on the pretraining dataset:

```bash
-nnUNetv2_preprocess -d SOURCE_DATSET -plans_name TARGET_PLANS_IDENTIFIER
+nnUNetv2_preprocess -d PRETRAINING_DATASET -plans_name TARGET_PLANS_IDENTIFIER
```

And run the training as usual:

```bash
-nnUNetv2_train SOURCE_DATSET CONFIG all -p TARGET_PLANS_IDENTIFIER
+nnUNetv2_train PRETRAINING_DATASET CONFIG all -p TARGET_PLANS_IDENTIFIER
```

Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data.
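Once pretraining is done, the resulting checkpoint initializes training on the target dataset via the `-pretrained_weights` option of `nnUNetv2_train`. A minimal sketch, assuming the standard `nnUNet_results` folder layout (dataset names, configuration and plans identifiers are placeholders; adapt the checkpoint path to your setup):

```bash
# Fine-tune on the target dataset, initializing from the pretraining checkpoint
nnUNetv2_train TARGET_DATASET CONFIG FOLD \
    -pretrained_weights ${nnUNet_results}/PRETRAINING_DATASET_NAME/nnUNetTrainer__TARGET_PLANS_IDENTIFIER__CONFIG/fold_all/checkpoint_final.pth
```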
60 changes: 60 additions & 0 deletions nnunetv2/dataset_conversion/Dataset224_AbdomenAtlas1.0.py
@@ -0,0 +1,60 @@
from batchgenerators.utilities.file_and_folder_operations import *
import shutil
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json


if __name__ == '__main__':
"""
How to train our submission to the JHU benchmark
1. Execute this script here to convert the dataset into nnU-Net format. Adapt the paths to your system!
2. Run planning and preprocessing: `nnUNetv2_plan_and_preprocess -d 224 -npfp 64 -np 64 -c 3d_fullres -pl
nnUNetPlannerResEncL_torchres`. Adapt the number of processes to your System (-np; -npfp)! Note that each process
will again spawn 4 threads for resampling. This custom planner replaces the nnU-Net default resampling scheme with
a torch-based implementation which is faster but less accurate. This is needed to satisfy the inference speed
constraints.
3. Run training with `nnUNetv2_train 224 3d_fullres all -p nnUNetResEncUNetLPlans_torchres`. 24GB VRAM required,
training will take ~28-30h.
"""


base = '/home/isensee/Downloads/AbdomenAtlas1.0Mini'
cases = subdirs(base, join=False, prefix='BDMAP')

target_dataset_id = 224
target_dataset_name = f'Dataset{target_dataset_id:3.0f}_AbdomenAtlas1.0'

raw_dir = '/home/isensee/drives/E132-Projekte/Projects/Helmholtz_Imaging_ACVL/2024_JHU_benchmark'
maybe_mkdir_p(join(raw_dir, target_dataset_name))
imagesTr = join(raw_dir, target_dataset_name, 'imagesTr')
labelsTr = join(raw_dir, target_dataset_name, 'labelsTr')
maybe_mkdir_p(imagesTr)
maybe_mkdir_p(labelsTr)

for case in cases:
shutil.copy(join(base, case, 'ct.nii.gz'), join(imagesTr, case + '_0000.nii.gz'))
shutil.copy(join(base, case, 'combined_labels.nii.gz'), join(labelsTr, case + '.nii.gz'))

labels = {
"background": 0,
"aorta": 1,
"gall_bladder": 2,
"kidney_left": 3,
"kidney_right": 4,
"liver": 5,
"pancreas": 6,
"postcava": 7,
"spleen": 8,
"stomach": 9
}

generate_dataset_json(
join(raw_dir, target_dataset_name),
{0: 'nonCT'}, # this was a mistake we did at the beginning and we keep it like that here for consistency
labels,
len(cases),
'.nii.gz',
None,
target_dataset_name,
overwrite_image_reader_writer='NibabelIOWithReorient'
)
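For convenience, the three steps from the docstring above collected into one command sequence. Dataset ID, planner and plans names, and process counts are taken verbatim from the docstring; the hard-coded paths inside the script and the way you invoke it are illustrative and must be adapted to your system:

```bash
# 1. Convert AbdomenAtlas1.0 into nnU-Net raw format (adapt the hard-coded paths in the script first)
python Dataset224_AbdomenAtlas1.0.py

# 2. Plan and preprocess with the torch-based resampling planner
nnUNetv2_plan_and_preprocess -d 224 -npfp 64 -np 64 -c 3d_fullres -pl nnUNetPlannerResEncL_torchres

# 3. Train on all available data (~24 GB VRAM, roughly 28-30 h)
nnUNetv2_train 224 3d_fullres all -p nnUNetResEncUNetLPlans_torchres
```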
1 change: 0 additions & 1 deletion nnunetv2/evaluation/evaluate_predictions.py
@@ -92,7 +92,6 @@ def compute_metrics(reference_file: str, prediction_file: str, image_reader_writ
    # load images
    seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file)
    seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file)
-    # spacing = seg_ref_dict['spacing']

    ignore_mask = seg_ref == ignore_label if ignore_label is not None else None

10 changes: 8 additions & 2 deletions nnunetv2/evaluation/find_best_configuration.py
@@ -3,8 +3,9 @@
from copy import deepcopy
from typing import Union, List, Tuple

-from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json

+from batchgenerators.utilities.file_and_folder_operations import (
+    load_json, join, isdir, listdir, save_json
+)
from nnunetv2.configuration import default_num_processes
from nnunetv2.ensembling.ensemble import ensemble_crossvalidations
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
@@ -320,6 +321,11 @@ def accumulate_crossval_results_entry_point():
        merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}')
    else:
        merged_output_folder = args.o
    if isdir(merged_output_folder) and len(listdir(merged_output_folder)) > 0:
        raise FileExistsError(
            f"Output folder {merged_output_folder} exists and is not empty. "
            f"To avoid data loss, nnUNet requires an empty output folder."
        )

    accumulate_cv_results(trained_model_folder, merged_output_folder, args.f)

Empty file.
@@ -0,0 +1,181 @@
from typing import Union, List, Tuple

from nnunetv2.configuration import ANISO_THRESHOLD
from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
from nnunetv2.experiment_planning.experiment_planners.residual_unets.residual_encoder_unet_planners import \
nnUNetPlannerResEncL
from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet


class nnUNetPlannerResEncL_torchres(nnUNetPlannerResEncL):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 24,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def determine_resampling(self, *args, **kwargs):
        """
        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
        configuration
        """
        resampling_data = resample_torch_fornnunet
        resampling_data_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        resampling_seg = resample_torch_fornnunet
        resampling_seg_kwargs = {
            "is_seg": True,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """
        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
        functions for each configuration
        """
        resampling_fn = resample_torch_fornnunet
        resampling_fn_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_fn, resampling_fn_kwargs


class nnUNetPlannerResEncL_torchres_sepz(nnUNetPlannerResEncL):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 24,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetLPlans_torchres_sepz',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def determine_resampling(self, *args, **kwargs):
        """
        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
        configuration
        """
        resampling_data = resample_torch_fornnunet
        resampling_data_kwargs = {
            "is_seg": False,
            'force_separate_z': None,
            'memefficient_seg_resampling': False,
            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
        }
        resampling_seg = resample_torch_fornnunet
        resampling_seg_kwargs = {
            "is_seg": True,
            'force_separate_z': None,
            'memefficient_seg_resampling': False,
            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
        }
        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """
        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
        functions for each configuration
        """
        resampling_fn = resample_torch_fornnunet
        resampling_fn_kwargs = {
            "is_seg": False,
            'force_separate_z': None,
            'memefficient_seg_resampling': False,
            'separate_z_anisotropy_threshold': ANISO_THRESHOLD
        }
        return resampling_fn, resampling_fn_kwargs


class nnUNetPlanner_torchres(ExperimentPlanner):
    def __init__(self, dataset_name_or_id: Union[str, int],
                 gpu_memory_target_in_gb: float = 8,
                 preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans_torchres',
                 overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
                 suppress_transpose: bool = False):
        super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
                         overwrite_target_spacing, suppress_transpose)

    def generate_data_identifier(self, configuration_name: str) -> str:
        """
        configurations are unique within each plans file but different plans file can have configurations with the
        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
        config but also the plans it originates from
        """
        return self.plans_identifier + '_' + configuration_name

    def determine_resampling(self, *args, **kwargs):
        """
        returns what functions to use for resampling data and seg, respectively. Also returns kwargs
        resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
        determine_resampling is called within get_plans_for_configuration to allow for different functions for each
        configuration
        """
        resampling_data = resample_torch_fornnunet
        resampling_data_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        resampling_seg = resample_torch_fornnunet
        resampling_seg_kwargs = {
            "is_seg": True,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs

    def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
        """
        function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
        used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
        determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
        functions for each configuration
        """
        resampling_fn = resample_torch_fornnunet
        resampling_fn_kwargs = {
            "is_seg": False,
            'force_separate_z': False,
            'memefficient_seg_resampling': False
        }
        return resampling_fn, resampling_fn_kwargs
@@ -6,6 +6,7 @@

from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm
from nnunetv2.preprocessing.resampling.resample_torch import resample_torch_fornnunet
from torch import nn

from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
4 changes: 0 additions & 4 deletions nnunetv2/imageio/nibabel_reader_writer.py
@@ -31,8 +31,6 @@ class NibabelIO(BaseReaderWriter):
    supported_file_endings = [
        '.nii',
        '.nii.gz',
-        '.nrrd',
-        '.mha'
    ]

    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
@@ -110,8 +108,6 @@ class NibabelIOWithReorient(BaseReaderWriter):
    supported_file_endings = [
        '.nii',
        '.nii.gz',
-        '.nrrd',
-        '.mha'
    ]

    def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
2 changes: 1 addition & 1 deletion nnunetv2/inference/JHU_inference.py
@@ -176,7 +176,7 @@ def predict_from_data_iterator(self,
    predictor.initialize_from_trained_model_folder(
        args.model,
        ('all', ),
-        'checkpoint_latest.pth'
+        'checkpoint_final.pth'
    )

    # we need to create list of list of input files
