diff --git a/documentation/competitions/Toothfairy2.md b/documentation/competitions/Toothfairy2/Toothfairy2.md
similarity index 83%
rename from documentation/competitions/Toothfairy2.md
rename to documentation/competitions/Toothfairy2/Toothfairy2.md
index c91c25904..42fba7dcf 100644
--- a/documentation/competitions/Toothfairy2.md
+++ b/documentation/competitions/Toothfairy2/Toothfairy2.md
@@ -1,3 +1,12 @@
+Authors: \
+Fabian Isensee*, Yannick Kirchhoff*, Lars Kraemer, Max Rokuss, Constantin Ulrich, Klaus H. Maier-Hein
+
+*: equal contribution
+
+Author Affiliations:\
+Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg \
+Helmholtz Imaging
+
 # Introduction
 
 This document describes our submission to the [Toothfairy2 Challenge](https://toothfairy2.grand-challenge.org/toothfairy2/).
@@ -7,6 +16,8 @@ mirroring and train for 1500 instead of the standard 1000 epochs. Training was e
 
 # Dataset Conversion
 
 # Experiment Planning and Preprocessing
+Adapt and run the [dataset conversion script](../../../nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py).
+This script converts the .mha files to NIfTI (smaller file size) and removes the unused label IDs.
 
 ## Extract fingerprint:
 `nnUNetv2_extract_fingerprint -d 119 -np 48`
@@ -160,6 +171,9 @@ Add the following configuration to the generated plans file:
 
 Aside from changing the patch size this makes the architecture one stage deeper (one more pooling + res blocks), enabling it to make effective use of the larger input
 
+# Preprocessing
+`nnUNetv2_preprocess -d 119 -c 3d_fullres_torchres_ps160x320x320_bs2 -plans_name nnUNetResEncUNetLPlans -np 48`
+
 # Training
 We train two models on all training cases:
 
@@ -167,6 +181,8 @@ We train two models on all training cases:
 nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans -tr nnUNetTrainer_onlyMirror01_1500ep
 nnUNet_results=${nnUNet_results}_2 nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans -tr nnUNetTrainer_onlyMirror01_1500ep
 ```
+Models are trained from scratch.
+
 Note how in the second line we overwrite the nnUNet_results variable in order to be able to train the same model twice without overwriting the results
 
 # Inference
@@ -174,4 +190,12 @@ We ensemble the two models from above. On a technical level we copy the two fold
 directory and rename them to fold_0 and fold_1. This lets us use nnU-Net's cross-validation ensembling strategy which
 is more computationally efficient (needed for time limit on grand-challenge.org).
 
-Run inference with the inference script
\ No newline at end of file
+Run inference with the [inference script](inference_script_semseg_only_customInf2.py).
+
+# Postprocessing
+If the predicted volume of a class on a test case is smaller than the corresponding cutoff, that prediction is removed
+(replaced with background).
+
+Cutoff values were optimized using a five-fold cross-validation on the Toothfairy2 training data. We optimize HD95 and Dice separately.
+The final cutoff for each class is then the smaller of the two values. You can find our volume cutoffs in the inference
+script as part of our `postprocess` function.
\ No newline at end of file
diff --git a/documentation/competitions/Toothfairy2/__init__.py b/documentation/competitions/Toothfairy2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/documentation/competitions/Toothfairy2/inference_script_semseg_only_customInf2.py b/documentation/competitions/Toothfairy2/inference_script_semseg_only_customInf2.py
new file mode 100644
index 000000000..488d47354
--- /dev/null
+++ b/documentation/competitions/Toothfairy2/inference_script_semseg_only_customInf2.py
@@ -0,0 +1,343 @@
+import argparse
+import gc
+import os
+from pathlib import Path
+from queue import Queue
+from threading import Thread
+from typing import Union, Tuple
+
+import nnunetv2
+import numpy as np
+import torch
+from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
+from acvl_utils.cropping_and_padding.padding import pad_nd_image
+from batchgenerators.utilities.file_and_folder_operations import load_json, join
+from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+from nnunetv2.inference.sliding_window_prediction import compute_gaussian
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.helpers import empty_cache, dummy_context
+from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+from torch._dynamo import OptimizedModule
+from torch.backends import cudnn
+from tqdm import tqdm
+
+
+class CustomPredictor(nnUNetPredictor):
+    def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
+                                 segmentation_previous_stage: np.ndarray = None):
+        torch.set_num_threads(7)
+        with torch.no_grad():
+            self.network = self.network.to(self.device)
+            self.network.eval()
+
+            if self.verbose:
+                print('preprocessing')
+            preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose)
+            data, _ = preprocessor.run_case_npy(input_image, None, image_properties,
+                                                self.plans_manager,
+                                                self.configuration_manager,
+                                                self.dataset_json)
+
+            data = torch.from_numpy(data)
+            del input_image
+            if self.verbose:
+                print('predicting')
+
+            predicted_logits = self.predict_preprocessed_image(data)
+
+            if self.verbose: print('Prediction done')
+
+            segmentation = self.convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits,
+                                                                                            image_properties)
+        return segmentation
+
+    def initialize_from_trained_model_folder(self, model_training_output_dir: str,
+                                             use_folds: Union[Tuple[Union[int, str]], None],
+                                             checkpoint_name: str = 'checkpoint_final.pth'):
+        """
+        This is used when making predictions with a trained model
+        """
+        if use_folds is None:
+            use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)
+
+        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
+        plans = load_json(join(model_training_output_dir, 'plans.json'))
+        plans_manager = PlansManager(plans)
+
+        if isinstance(use_folds, str):
+            use_folds = [use_folds]
+
+        parameters = []
+        for i, f in enumerate(use_folds):
+            f = int(f) if f != 'all' else f
+            checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
+                                    map_location=torch.device('cpu'))
+            if i == 0:
+                trainer_name = checkpoint['trainer_name']
+                configuration_name = checkpoint['init_args']['configuration']
+                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
+                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None
+
+            parameters.append(join(model_training_output_dir, f'fold_{f}', checkpoint_name))
+
+        configuration_manager = plans_manager.get_configuration(configuration_name)
+        # restore network
+        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
+        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
+                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
+        if trainer_class is None:
+            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
+                               f'Please place it there (in any .py file)!')
+        network = trainer_class.build_network_architecture(
+            configuration_manager.network_arch_class_name,
+            configuration_manager.network_arch_init_kwargs,
+            configuration_manager.network_arch_init_kwargs_req_import,
+            num_input_channels,
+            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
+            enable_deep_supervision=False
+        )
+
+        self.plans_manager = plans_manager
+        self.configuration_manager = configuration_manager
+        self.list_of_parameters = parameters
+        self.network = network
+        self.dataset_json = dataset_json
+        self.trainer_name = trainer_name
+        self.allowed_mirroring_axes = inference_allowed_mirroring_axes
+        self.label_manager = plans_manager.get_label_manager(dataset_json)
+        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
+                and not isinstance(self.network, OptimizedModule):
+            print('Using torch.compile')
+            self.network = torch.compile(self.network)
+
+    @torch.inference_mode(mode=True)
+    def predict_preprocessed_image(self, image):
+        empty_cache(self.device)
+        data_device = torch.device('cpu')
+        predicted_logits_device = torch.device('cpu')
+        gaussian_device = torch.device('cpu')
+        compute_device = torch.device('cuda:0')
+
+        data, slicer_revert_padding = pad_nd_image(image, self.configuration_manager.patch_size,
+                                                   'constant', {'value': 0}, True,
+                                                   None)
+        del image
+
+        slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
+
+        empty_cache(self.device)
+
+        data = data.to(data_device)
+        predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
+                                       dtype=torch.half,
+                                       device=predicted_logits_device)
+        gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
+                                    value_scaling_factor=10,
+                                    device=gaussian_device, dtype=torch.float16)
+
+        if not self.allow_tqdm and self.verbose:
+            print(f'running prediction: {len(slicers)} steps')
+
+        for p in self.list_of_parameters:
+            # network weights have to be updated outside autocast!
+            # we are loading parameters on demand instead of loading them upfront. This reduces memory footprint a lot.
+            # each set of parameters is only used once on the test set (one image) so run time wise this is almost the
+            # same
+            self.network.load_state_dict(torch.load(p, map_location=compute_device)['network_weights'])
+            with torch.autocast(self.device.type, enabled=True):
+                for sl in tqdm(slicers, disable=not self.allow_tqdm):
+                    pred = self._internal_maybe_mirror_and_predict(data[sl][None].to(compute_device))[0].to(
+                        predicted_logits_device)
+                    pred /= (pred.max() / 100)
+                    predicted_logits[sl] += (pred * gaussian)
+                    del pred
+        empty_cache(self.device)
+        return predicted_logits
+
+    def convert_predicted_logits_to_segmentation_with_correct_shape(self, predicted_logits, props):
+        old = torch.get_num_threads()
+        torch.set_num_threads(7)
+
+        # resample to original shape
+        spacing_transposed = [props['spacing'][i] for i in self.plans_manager.transpose_forward]
+        current_spacing = self.configuration_manager.spacing if \
+            len(self.configuration_manager.spacing) == \
+            len(props['shape_after_cropping_and_before_resampling']) else \
+            [spacing_transposed[0], *self.configuration_manager.spacing]
+        predicted_logits = self.configuration_manager.resampling_fn_probabilities(
+            predicted_logits,
+            props['shape_after_cropping_and_before_resampling'],
+            current_spacing,
+            [props['spacing'][i] for i in self.plans_manager.transpose_forward])
+
+        segmentation = None
+        pp = None
+        try:
+            with torch.no_grad():
+                pp = predicted_logits.to('cuda:0')
+                segmentation = pp.argmax(0).cpu()
+                del pp
+        except RuntimeError:
+            del segmentation, pp
+            torch.cuda.empty_cache()
+            segmentation = predicted_logits.argmax(0)
+        del predicted_logits
+
+        # segmentation may be torch.Tensor but we continue with numpy
+        if isinstance(segmentation, torch.Tensor):
+            segmentation = segmentation.cpu().numpy()
+
+        # put segmentation in bbox (revert cropping)
+        segmentation_reverted_cropping = np.zeros(props['shape_before_cropping'],
+                                                  dtype=np.uint8 if len(self.label_manager.foreground_labels) < 255 else np.uint16)
+        slicer = bounding_box_to_slice(props['bbox_used_for_cropping'])
+        segmentation_reverted_cropping[slicer] = segmentation
+        del segmentation
+
+        # revert transpose
+        segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(self.plans_manager.transpose_backward)
+        torch.set_num_threads(old)
+        return segmentation_reverted_cropping
+
+
+def predict_semseg(im, prop, semseg_trained_model, semseg_folds):
+    # initialize predictors
+    pred_semseg = CustomPredictor(
+        tile_step_size=0.5,
+        use_mirroring=True,
+        use_gaussian=True,
+        perform_everything_on_device=False,
+        allow_tqdm=True
+    )
+    pred_semseg.initialize_from_trained_model_folder(
+        semseg_trained_model,
+        use_folds=semseg_folds,
+        checkpoint_name='checkpoint_final.pth'
+    )
+
+    semseg_pred = pred_semseg.predict_single_npy_array(
+        im, prop, None
+    )
+    torch.cuda.empty_cache()
+    gc.collect()
+    return semseg_pred
+
+
+def map_labels_to_toothfairy(predicted_seg: np.ndarray) -> np.ndarray:
+    # Create an array that maps the labels directly
+    max_label = 42
+    mapping = np.arange(max_label + 1)
+
+    # Define the specific remapping
+    remapping = {19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 26, 25: 27, 26: 28,
+                 27: 31, 28: 32, 29: 33, 30: 34, 31: 35, 32: 36, 33: 37, 34: 38,
+                 35: 41, 36: 42, 37: 43, 38: 44, 39: 45, 40: 46, 41: 47, 42: 48}
+
+    # Apply the remapping
+    for k, v in remapping.items():
+        mapping[k] = v
+
+    return mapping[predicted_seg]
+
+
+def postprocess(prediction_npy, vol_per_voxel, verbose: bool = False):
+    cutoffs = {1: 0.0,
+               2: 78411.5,
+               3: 0.0,
+               4: 0.0,
+               5: 2800.0,
+               6: 1216.5,
+               7: 0.0,
+               8: 6222.0,
+               9: 1573.0,
+               10: 946.0,
+               11: 0.0,
+               12: 6783.5,
+               13: 9469.5,
+               14: 0.0,
+               15: 2260.0,
+               16: 3566.0,
+               17: 6321.0,
+               18: 4221.5,
+               19: 5829.0,
+               20: 0.0,
+               21: 0.0,
+               22: 468.0,
+               23: 1555.0,
+               24: 1291.5,
+               25: 2834.5,
+               26: 584.5,
+               27: 0.0,
+               28: 0.0,
+               29: 0.0,
+               30: 0.0,
+               31: 1935.5,
+               32: 0.0,
+               33: 0.0,
+               34: 6140.0,
+               35: 0.0,
+               36: 0.0,
+               37: 0.0,
+               38: 2710.0,
+               39: 0.0,
+               40: 0.0,
+               41: 0.0,
+               42: 970.0}
+
+    # cutoff values are compared at a 0.3 mm isotropic voxel volume -> convert them to mm^3
+    vol_per_voxel_cutoffs = 0.3 * 0.3 * 0.3
+    for c in cutoffs.keys():
+        co = cutoffs[c]
+        if co > 0:
+            mask = prediction_npy == c
+            pred_vol = np.sum(mask) * vol_per_voxel
+            if 0 < pred_vol < (co * vol_per_voxel_cutoffs):
+                prediction_npy[mask] = 0
+                if verbose:
+                    print(
+                        f'removed label {c} because predicted volume of {pred_vol} is less than the cutoff {co * vol_per_voxel_cutoffs}')
+    return prediction_npy
+
+
+if __name__ == '__main__':
+    os.environ['nnUNet_compile'] = 'f'
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input_folder', type=Path, default="/input/images/cbct/")
+    parser.add_argument('-o', '--output_folder', type=Path, default="/output/images/oral-pharyngeal-segmentation/")
+    parser.add_argument('-sem_mod', '--semseg_trained_model', type=str,
+                        default="/opt/app/_trained_model/semseg_trained_model")
+    parser.add_argument('--semseg_folds', type=str, nargs='+', default=[0, 1])
+    args = parser.parse_args()
+
+    args.output_folder.mkdir(exist_ok=True, parents=True)
+
+    semseg_folds = [i if i == 'all' else int(i) for i in args.semseg_folds]
+    semseg_trained_model = args.semseg_trained_model
+
+    rw = SimpleITKIO()
+
+    input_files = list(args.input_folder.glob('*.nii.gz')) + list(args.input_folder.glob('*.mha'))
+
+    for input_fname in input_files:
+        output_fname = args.output_folder / input_fname.name
+
+        # this script only runs the semantic segmentation model
+        # load test image
+        im, prop = rw.read_images([input_fname])
+
+        with torch.no_grad():
+            semseg_pred = predict_semseg(im, prop, semseg_trained_model, semseg_folds)
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        # now postprocess
+        semseg_pred = postprocess(semseg_pred, np.prod(prop['spacing']), True)
+
+        # map the continuous training label IDs back to the original ToothFairy2 label IDs
+        semseg_pred = map_labels_to_toothfairy(semseg_pred)
+
+        # now save
+        rw.write_seg(semseg_pred, output_fname, prop)
diff --git a/documentation/competitions/__init__.py b/documentation/competitions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py b/nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py
new file mode 100644
index 000000000..7fc3da26a
--- /dev/null
+++ b/nnunetv2/dataset_conversion/Dataset119_ToothFairy2_All.py
@@ -0,0 +1,196 @@
+from typing import Dict, Any
+import os
+from os.path import join
+import json
+import random
+import multiprocessing
+
+import SimpleITK as sitk
+import numpy as np
+from tqdm import tqdm
+
+
+def mapping_DS119() -> Dict[int, int]:
+    """Remove all NA classes and make class IDs continuous"""
+    mapping = {}
+    mapping.update({i: i for i in range(1, 19)})  # [1-10]->[1-10] | [11-18]->[11-18]
+    mapping.update({i: i - 2 for i in range(21, 29)})  # [21-28]->[19-26]
+    mapping.update({i: i - 4 for i in range(31, 39)})  # [31-38]->[27-34]
+    mapping.update({i: i - 6 for i in range(41, 49)})  # [41-48]->[35-42]
+    return mapping
+
+
+def mapping_DS120() -> Dict[int, int]:
+    """Keep only the jaw and teeth classes"""
+    mapping = {}
+    mapping.update({i: i for i in range(1, 3)})  # [1-2]->[1-2]
+    mapping.update({i: i - 8 for i in range(11, 19)})  # [11-18]->[3-10]
+    mapping.update({i: i - 10 for i in range(21, 29)})  # [21-28]->[11-18]
+    mapping.update({i: i - 12 for i in range(31, 39)})  # [31-38]->[19-26]
+    mapping.update({i: i - 14 for i in range(41, 49)})  # [41-48]->[27-34]
+    return mapping
+
+
+def mapping_DS121() -> Dict[int, int]:
+    """Keep only the teeth classes"""
+    mapping = {}
+    mapping.update({i: i - 10 for i in range(11, 19)})  # [11-18]->[1-8]
+    mapping.update({i: i - 12 for i in range(21, 29)})  # [21-28]->[9-16]
+    mapping.update({i: i - 14 for i in range(31, 39)})  # [31-38]->[17-24]
+    mapping.update({i: i - 16 for i in range(41, 49)})  # [41-48]->[25-32]
+    return mapping
+
+
+def load_json(json_file: str) -> Any:
+    with open(json_file, "r") as f:
+        data = json.load(f)
+    return data
+
+
+def write_json(json_file: str, data: Any, indent: int = 4) -> None:
+    with open(json_file, "w") as f:
+        json.dump(data, f, indent=indent)
+
+
+def image_to_nifti(input_path: str, output_path: str) -> None:
+    image_sitk = sitk.ReadImage(input_path)
+    sitk.WriteImage(image_sitk, output_path)
+
+
+def label_mapping(input_path: str, output_path: str, mapping: Dict[int, int] = None) -> None:
+
+    label_sitk = sitk.ReadImage(input_path)
+    if mapping is not None:
+        label_np = sitk.GetArrayFromImage(label_sitk)
+
+        label_np_new = np.zeros_like(label_np, dtype=np.uint8)
+        for org_id, new_id in mapping.items():
+            label_np_new[label_np == org_id] = new_id
+
+        label_sitk_new = sitk.GetImageFromArray(label_np_new)
+        label_sitk_new.CopyInformation(label_sitk)
+        sitk.WriteImage(label_sitk_new, output_path)
+    else:
+        sitk.WriteImage(label_sitk, output_path)
+
+
+def process_images(files: str, img_dir_in: str, img_dir_out: str, n_processes: int = 12):
+
+    os.makedirs(img_dir_out, exist_ok=True)
+
+    iterable = [
+        {
+            "input_path": join(img_dir_in, file),
+            "output_path": join(img_dir_out, file.replace(".mha", ".nii.gz")),
+        }
+        for file in files
+    ]
+    with multiprocessing.Pool(processes=n_processes) as pool:
+        jobs = [pool.apply_async(image_to_nifti, kwds={**args}) for args in iterable]
+        _ = [job.get() for job in tqdm(jobs, desc="Process Images")]
+
+
+def process_labels(
+    files: str, lbl_dir_in: str, lbl_dir_out: str, mapping: Dict[int, int], n_processes: int = 12
+) -> None:
+
+    os.makedirs(lbl_dir_out, exist_ok=True)
+
+    iterable = [
+        {
+            "input_path": join(lbl_dir_in, file),
+            "output_path": join(lbl_dir_out, file.replace(".mha", ".nii.gz")),
+            "mapping": mapping,
+        }
+        for file in files
+    ]
+    with multiprocessing.Pool(processes=n_processes) as pool:
+        jobs = [pool.apply_async(label_mapping, kwds={**args}) for args in iterable]
+        _ = [job.get() for job in tqdm(jobs, desc="Process Labels...")]
+
+
+def process_ds(
+    root: str, input_ds: str, output_ds: str, mapping: dict, image_link: str = None
+) -> None:
+    os.makedirs(join(root, output_ds), exist_ok=True)
+    os.makedirs(join(root, output_ds, "labelsTr"), exist_ok=True)
+    # --- Handle Labels --- #
+    lbl_files = os.listdir(join(root, input_ds, "labelsTr"))
+    lbl_dir_in = join(root, input_ds, "labelsTr")
+    lbl_dir_out = join(root, output_ds, "labelsTr")
+
+    process_labels(lbl_files, lbl_dir_in, lbl_dir_out, mapping, n_processes=12)
+
+    # --- Handle Images --- #
+    img_files = os.listdir(join(root, input_ds, "imagesTr"))
+    dataset = {}
+    if image_link is None:
+        img_dir_in = join(root, input_ds, "imagesTr")
+        img_dir_out = join(root, output_ds, "imagesTr")
+
+        process_images(img_files, img_dir_in, img_dir_out, n_processes=12)
+    else:
+        base_name = [file.replace("_0000.mha", "") for file in img_files]
+        for name in base_name:
+            dataset[name] = {
+                "images": [join("..", image_link, "imagesTr", name + "_0000.nii.gz")],
+                "label": join("labelsTr", name + ".nii.gz"),
+            }
+
+    # --- Generate dataset.json --- #
+    dataset_json = load_json(join(root, input_ds, "dataset.json"))
+    dataset_json["file_ending"] = ".nii.gz"
+    dataset_json["name"] = output_ds
+    dataset_json["numTraining"] = len(lbl_files)
+    if dataset != {}:
+        dataset_json["dataset"] = dataset
+
+    label_dict = dataset_json["labels"]
+    label_dict_new = {"background": 0}
+    for k, v in label_dict.items():
+        if v in mapping.keys():
+            label_dict_new[k] = mapping[v]
+    dataset_json["labels"] = label_dict_new
+    write_json(join(root, output_ds, "dataset.json"), dataset_json)
+
+    # --- Generate splits_final.json --- #
+    img_names = [file.replace("_0000.mha", "") for file in img_files]
+
+    random_seed = 42
+    random.seed(random_seed)
+    random.shuffle(img_names)
+
+    split_index = int(len(img_names) * 0.7)  # 70:30 split
+    train_files = img_names[:split_index]
+    val_files = img_names[split_index:]
+    train_files.sort()
+    val_files.sort()
+
+    split = [{"train": train_files, "val": val_files}]
+    write_json(join(root, output_ds, "splits_final.json"), split)
+
+
+if __name__ == "__main__":
+    # Different nnUNet Datasets
+    # Dataset 112: Raw
+    # Dataset 119: Remove NA classes (continuous label IDs)
+    # Dataset 120: Only Teeth + Jaw Classes
+    # Dataset 121: Only Teeth Classes
+
+    root = "/media/l727r/data/Teeth_Data/ToothFairy2_Dataset"
+
+    process_ds(root, "Dataset112_ToothFairy2", "Dataset119_ToothFairy2_All", mapping_DS119(), None)
+    # process_ds(
+    #     root,
+    #     "Dataset112_ToothFairy2",
+    #     "Dataset120_ToothFairy2_JawTeeth",
+    #     mapping_DS120(),
+    #     "Dataset119_ToothFairy2_All",
+    # )
+    # process_ds(
+    #     root,
+    #     "Dataset112_ToothFairy2",
+    #     "Dataset121_ToothFairy2_Teeth",
+    #     mapping_DS121(),
+    #     "Dataset119_ToothFairy2_All",
+    # )
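For readers who want to sanity-check the postprocessing rule described in the Toothfairy2.md above without running the full inference script, below is a minimal, self-contained sketch of the volume-cutoff logic. The label IDs, cutoff values and the toy segmentation are made up for illustration only; the real per-class cutoffs are the ones hard-coded in the `postprocess` function of `inference_script_semseg_only_customInf2.py`.

```python
import numpy as np

# Illustrative cutoffs for two hypothetical labels (the real values live in `postprocess` above).
# A cutoff of 0.0 means predictions of that label are never removed.
cutoffs = {1: 500.0, 2: 0.0}

# Mirrors `vol_per_voxel_cutoffs` in the inference script: cutoffs are compared at a
# 0.3 mm isotropic voxel volume, i.e. they are converted to mm^3 before the comparison.
vol_per_voxel_cutoffs = 0.3 * 0.3 * 0.3


def apply_volume_cutoffs(seg: np.ndarray, vol_per_voxel: float) -> np.ndarray:
    """Remove every label whose predicted volume (in mm^3) falls below its cutoff."""
    for label, cutoff in cutoffs.items():
        if cutoff <= 0:
            continue
        mask = seg == label
        pred_vol = mask.sum() * vol_per_voxel
        if 0 < pred_vol < cutoff * vol_per_voxel_cutoffs:
            seg[mask] = 0  # replace with background
    return seg


# Toy example: label 1 covers only 10 voxels at 0.3 mm spacing (0.27 mm^3) and is removed,
# label 2 has a cutoff of 0.0 and is always kept.
seg = np.zeros((20, 20, 20), dtype=np.uint8)
seg[:10, 0, 0] = 1
seg[10:, 0, 0] = 2
seg = apply_volume_cutoffs(seg, vol_per_voxel=0.3 ** 3)
print(np.unique(seg))  # -> [0 2]
```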